vllm.utils.deep_gemm
Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import these wrappers rather than calling deep_gemm directly.
DeepGemmQuantScaleFMT
Bases: Enum
Source code in vllm/utils/deep_gemm.py
from_oracle classmethod
from_oracle() -> DeepGemmQuantScaleFMT
Return the pre-initialized oracle decision
init_oracle_cache classmethod
Initialize the oracle decision and store it in the class cache
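A minimal sketch of the compute-once, read-many pattern these two classmethods describe. The enum name ScaleFmt, its members, and the use_ue8m0 argument are hypothetical stand-ins; the real init_oracle_cache derives its decision from the platform and configuration.

```python
from enum import Enum


class ScaleFmt(Enum):
    """Hypothetical stand-in for DeepGemmQuantScaleFMT; member names are illustrative."""

    FLOAT32 = "float32"  # legacy fp32 scale factors
    UE8M0 = "ue8m0"      # power-of-two exponents, packed 4 per int32

    @classmethod
    def init_oracle_cache(cls, use_ue8m0: bool) -> None:
        # Decide once (e.g. during startup) and store the decision.
        _ORACLE_CACHE[cls] = cls.UE8M0 if use_ue8m0 else cls.FLOAT32

    @classmethod
    def from_oracle(cls) -> "ScaleFmt":
        # Hot paths only read the stored decision; nothing is recomputed.
        if cls not in _ORACLE_CACHE:
            raise RuntimeError("call init_oracle_cache() first")
        return _ORACLE_CACHE[cls]


# Kept outside the class body: plain attributes assigned inside an Enum
# body would themselves become enum members.
_ORACLE_CACHE: dict = {}
```

Splitting initialization from lookup keeps platform queries off the hot path and makes a forgotten initialization fail loudly instead of silently picking a default.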
_import_deep_gemm
Import the deep_gemm module.
Prefers an externally installed deep_gemm package (so users can pin a specific version), then falls back to the vendored copy bundled in the vLLM wheel.
Returns None when neither source is usable.
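The fallback order above can be sketched as "try each candidate module in turn". This is a hypothetical helper, not the wrapper's actual code; it treats only ImportError as "not usable", and VENDORED_PATH in the usage comment is a placeholder for wherever the bundled copy lives.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence


def import_first_available(candidates: Sequence[str]) -> Optional[ModuleType]:
    """Return the first module in `candidates` that imports cleanly, else None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            # Not installed (or failed to import); fall through to the next source.
            continue
    return None


# Hypothetical usage; VENDORED_PATH stands in for the bundled copy's module path:
#   deep_gemm = import_first_available(["deep_gemm", VENDORED_PATH])
```

Listing the external package first is what lets a user-pinned version win over the vendored one.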
_lazy_init
Import deep_gemm and resolve symbols on first use.
_missing
Placeholder for unavailable DeepGEMM backend.
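Binding unresolved entry points to a placeholder turns a missing backend into a clear error at call time instead of an AttributeError at import time. The names missing_backend and resolve, and the error message, are hypothetical.

```python
from typing import Any, Callable, NoReturn, Optional


def missing_backend(*_args: Any, **_kwargs: Any) -> NoReturn:
    """Placeholder bound to any DeepGEMM entry point that could not be resolved."""
    raise RuntimeError(
        "DeepGEMM backend is not available: install a compatible deep_gemm "
        "package or run on a supported GPU."
    )


def resolve(symbol: Optional[Callable[..., Any]]) -> Callable[..., Any]:
    """Return the real callable when it exists, otherwise the raising placeholder."""
    return symbol if symbol is not None else missing_backend
```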
calc_diff
Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.
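One cosine-style metric with this behaviour is sim = 2⟨x, y⟩ / (‖x‖² + ‖y‖²), so 1 - sim = ‖x - y‖² / (‖x‖² + ‖y‖²). That exact formula is an assumption here, and plain lists stand in for flattened tensors. The result is 0 for identical inputs, 1 for orthogonal ones, and 2 for y = -x.

```python
import math
from typing import Sequence


def calc_diff_ref(x: Sequence[float], y: Sequence[float]) -> float:
    """Return 1 - sim with sim = 2<x, y> / (|x|^2 + |y|^2)."""
    if len(x) != len(y):
        raise ValueError("inputs must have the same number of elements")
    dot = math.fsum(a * b for a, b in zip(x, y))
    denom = math.fsum(a * a + b * b for a, b in zip(x, y))
    if denom == 0.0:
        return 0.0  # both inputs are all zeros, hence identical
    return 1.0 - 2.0 * dot / denom
```

A test would then assert something like calc_diff_ref(out, ref) < 1e-3; that threshold is illustrative, not taken from vLLM.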
fp8_fp4_mqa_logits
fp8_fp4_mqa_logits(
q: tuple[Tensor, Tensor | None],
kv: tuple[Tensor, Tensor],
weights: Tensor,
cu_seqlen_ks: Tensor,
cu_seqlen_ke: Tensor,
clean_logits: bool,
) -> Tensor
Compute MQA logits for a single sequence without KV paging.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes q = (values, scales_or_None) where scales is None for FP8 Q (per-token scale is folded into weights) and a packed block-scale tensor for MXFP4 Q.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
q | tuple[Tensor, Tensor | None] | Tuple | required |
kv | tuple[Tensor, Tensor] | Tuple | required |
weights | Tensor | weights of shape [M, H], dtype | required |
cu_seqlen_ks | Tensor | Start indices (inclusive) for valid K per query position, shape [M], dtype int32. | required |
cu_seqlen_ke | Tensor | End indices (exclusive) for valid K per query position, shape [M], dtype int32. | required |
clean_logits | bool | Whether to clean the unfilled logits into | required |
Returns:
| Type | Description |
|---|---|
Tensor | Logits tensor of shape [M, N], dtype |
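The cu_seqlen_ks / cu_seqlen_ke window semantics can be illustrated with a dense, unquantized reference on plain lists. The per-position score (a ReLU of each head's q·k, weighted by weights[m][h] and summed over heads) and the -inf fill for positions outside the window are assumptions; this page does not spell out the kernel's exact formula.

```python
import math
from typing import List

Vec = List[float]


def mqa_logits_ref(
    q: List[List[Vec]],     # [M][H][D], already dequantized
    k: List[Vec],           # [N][D], the single shared KV head
    weights: List[Vec],     # [M][H]
    ks: List[int],          # [M], first valid key (inclusive)
    ke: List[int],          # [M], end of valid keys (exclusive)
) -> List[List[float]]:
    # Assumption: out-of-window logits are filled with -inf (clean_logits=True).
    logits = [[-math.inf] * len(k) for _ in q]
    for m, (q_m, w_m) in enumerate(zip(q, weights)):
        for n in range(ks[m], ke[m]):
            # Assumed score: ReLU(q_h . k_n) weighted per head, summed over heads.
            logits[m][n] = sum(
                w * max(0.0, sum(a * b for a, b in zip(q_h, k[n])))
                for q_h, w in zip(q_m, w_m)
            )
    return logits
```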
fp8_fp4_paged_mqa_logits
fp8_fp4_paged_mqa_logits(
q: tuple[Tensor, Tensor | None],
kv_cache: Tensor,
weights: Tensor,
context_lens: Tensor,
block_tables: Tensor,
schedule_metadata: Tensor,
max_model_len: int,
clean_logits: bool,
) -> Tensor
Compute MQA logits using a paged KV-cache.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes q = (values, scales_or_None); pass (q_tensor, None) for the FP8 path and (q_values, q_scale) for MXFP4.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
q | tuple[Tensor, Tensor | None] | Tuple | required |
kv_cache | Tensor | Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, D+4], dtype | required |
weights | Tensor | Tensor of shape [B * next_n, H], dtype | required |
context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length for each batch element. | required |
block_tables | Tensor | Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. | required |
schedule_metadata | Tensor | Returned by | required |
max_model_len | int | Maximum sequence length used to size the logits output. | required |
clean_logits | bool | Whether to clean the unfilled logits into | required |
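block_tables[b] and context_lens[b] together determine which cache slots hold sequence b's keys. A minimal sketch of that address translation, assuming the usual paged layout where token pos lives in logical block pos // block_size at offset pos % block_size (the function name is hypothetical):

```python
from typing import List, Tuple


def logical_to_physical(
    block_table: List[int], context_len: int, block_size: int
) -> List[Tuple[int, int]]:
    """Return (physical_block, offset) for every valid token of one sequence."""
    slots = []
    for pos in range(context_len):
        logical_block, offset = divmod(pos, block_size)
        slots.append((block_table[logical_block], offset))
    return slots
```

For block_table = [7, 2], block_size = 4 and a context length of 5, the first four tokens live in physical block 7 and the fifth in block 2.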
Returns:
| Type | Description |
|---|---|
Tensor | Logits tensor of shape [B * next_n, max_model_len], dtype |
get_col_major_tma_aligned_tensor
Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor
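TMA requires global-memory strides that are multiples of 16 bytes, so an MN-major (column-major) scale tensor of logical shape [mn, k] needs its MN extent rounded up before the K stride is fixed. A sketch of that arithmetic, assuming only the MN extent is padded:

```python
from typing import Tuple


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def tma_aligned_size(n: int, element_size: int, alignment_bytes: int = 16) -> int:
    """Round n elements up so that n * element_size is a multiple of alignment_bytes."""
    if alignment_bytes % element_size:
        raise ValueError("element size must divide the alignment")
    step = alignment_bytes // element_size
    return ceil_div(n, step) * step


def mn_major_strides(mn: int, element_size: int) -> Tuple[int, int]:
    """Element strides (along MN, along K) for an MN-major [mn, k] tensor."""
    return 1, tma_aligned_size(mn, element_size)
```

For fp32 scales (4 bytes), mn = 130 pads to a K stride of 132 elements.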
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
Grouped (3D, expert-batched) variant of get_mn_major_tma_aligned_packed_ue8m0_tensor. Use for MoE weight scale tensors of shape (num_experts, mn, k_scale).
get_mn_major_tma_aligned_packed_ue8m0_tensor
Pack UE8M0 (uint8) → int32 with the MN-major TMA-aligned layout the DeepGEMM kernels consume directly. 16× smaller than the fp32 legacy SF format. Use for non-grouped 2D scale tensors.
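The shapes and strides of the packed 2D layout and of its grouped 3D variant can be sketched as follows. This assumes four UE8M0 bytes per int32 along the scale-K axis, MN padded to 16 bytes (4 int32), and experts stacked along the leading dimension; the exact layout the DeepGEMM helpers produce may differ.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def packed_ue8m0_layout(mn: int, k_scale: int) -> Tuple[Shape, Shape]:
    """(shape, strides) of the packed int32 tensor for a [mn, k_scale] uint8 input."""
    packed_k = ceil_div(k_scale, 4)   # four 1-byte exponents per int32
    aligned_mn = ceil_div(mn, 4) * 4  # 4 int32 = 16 bytes for TMA
    return (mn, packed_k), (1, aligned_mn)


def grouped_packed_ue8m0_layout(num_experts: int, mn: int, k_scale: int) -> Tuple[Shape, Shape]:
    """Same layout per expert, experts stacked along the leading dimension."""
    (_, packed_k), (_, aligned_mn) = packed_ue8m0_layout(mn, k_scale)
    return (num_experts, mn, packed_k), (packed_k * aligned_mn, 1, aligned_mn)
```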
get_paged_mqa_logits_metadata
Build scheduling metadata for paged MQA logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length per batch element. | required |
block_size | int | KV-cache block size in tokens (e.g., 64). | required |
num_sms | int | Number of SMs available, e.g. 132 on an H100 SXM (Hopper). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Backend-specific tensor consumed by the paged MQA logits kernel to schedule work across SMs. |
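As a rough illustration of "scheduling work across SMs", the hypothetical splitter below divides the total number of KV blocks evenly across SMs and records where each SM starts as a (sequence index, block within sequence) pair. The real metadata is backend-specific; this format is invented for illustration only.

```python
from typing import List, Tuple


def split_blocks_across_sms(
    context_lens: List[int], block_size: int, num_sms: int
) -> List[Tuple[int, int]]:
    """Return num_sms + 1 boundaries; SM i handles [bounds[i], bounds[i + 1])."""
    blocks = [-(-c // block_size) for c in context_lens]  # KV blocks per sequence
    total = sum(blocks)
    per_sm = -(-total // num_sms)  # even split, rounded up
    bounds = []
    for sm in range(num_sms + 1):
        remaining = min(sm * per_sm, total)  # flat index of this SM's first block
        b = 0
        # Convert the flat block index into (sequence index, block within sequence).
        while b < len(blocks) and remaining >= blocks[b]:
            remaining -= blocks[b]
            b += 1
        bounds.append((b, remaining))
    return bounds
```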
get_theoretical_mk_alignment_for_contiguous_layout
get_theoretical_mk_alignment_for_contiguous_layout(
expected_m: int | None = None,
num_groups: int | None = None,
) -> int
Per-call optimal M alignment for grouped contiguous GEMMs.
expected_m is the TOTAL routed tokens (sum across experts, typically M × num_topk). num_groups is the number of experts on this rank. The helper divides to recover per-expert em and picks an alignment based on data-driven thresholds (see deep_gemm runtime.hpp comments).
Calls that omit num_groups are treated as passing an already per-expert em (legacy behaviour, kept for backward compatibility).
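A sketch of the selection logic described above: recover per-expert em, then pick the first matching threshold. The thresholds and default here are caller-supplied placeholders, not DeepGEMM's real values (those live in runtime.hpp), and rounding em up is an assumption.

```python
from typing import Optional, Sequence, Tuple


def pick_mk_alignment(
    expected_m: Optional[int],
    num_groups: Optional[int],
    thresholds: Sequence[Tuple[int, int]],  # ascending (max_em, alignment) pairs
    default: int,
) -> int:
    if expected_m is None:
        return default
    if num_groups is None:
        em = expected_m  # legacy callers already pass per-expert em
    else:
        em = -(-expected_m // num_groups)  # total routed tokens -> per expert
    for max_em, alignment in thresholds:
        if em <= max_em:
            return alignment
    return default
```

For example, M = 64 tokens with num_topk = 8 gives expected_m = 512; over 64 local experts that is em = 8.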
is_deep_gemm_e8m0_used cached
is_deep_gemm_e8m0_used() -> bool
Return True if vLLM is configured to use DeepGEMM E8M0 scales on a Hopper- or Blackwell-class GPU.
is_deep_gemm_supported cached
is_deep_gemm_supported() -> bool
Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.
mk_alignment_scope
mk_alignment_scope(value: int)
Temporarily set DeepGEMM's BLOCK_M cap, restoring on exit.
Use around a sequence of grouped-contiguous GEMM calls whose workspace is padded to value (typically the per_call_align returned by compute_aligned_M_and_alignment).
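The set-then-restore behaviour maps naturally onto contextlib.contextmanager. In this sketch, the module-level _block_m_cap and its getter and setter are hypothetical stand-ins for DeepGEMM's global BLOCK_M cap; the try/finally ensures the previous value returns even if a GEMM call raises.

```python
from contextlib import contextmanager
from typing import Iterator

_block_m_cap = 128  # hypothetical process-wide cap


def set_block_m_cap(value: int) -> None:
    global _block_m_cap
    _block_m_cap = value


def get_block_m_cap() -> int:
    return _block_m_cap


@contextmanager
def block_m_cap_scope(value: int) -> Iterator[None]:
    """Apply value for the duration of the with-block, then restore the old cap."""
    previous = get_block_m_cap()
    set_block_m_cap(value)
    try:
        yield
    finally:
        set_block_m_cap(previous)  # restored even on exceptions
```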
pack_ue8m0_to_int
Pack 4 UE8M0 (uint8) scales into one int32.
DeepGEMM's SM100/SM120 FP8/FP4 kernels accept either float32 scales (legacy format, 4 B/scale) or int32 packed UE8M0 scales (1 B/scale after 4:1 packing — 4× smaller than the legacy fp32 representation).
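A UE8M0 byte e encodes the power of two 2^(e - 127), so four fp32 scales (16 bytes) carry the same information as four exponent bytes packed into one 4-byte int32. A sketch of the packing, assuming little-endian placement (scale j in bits 8j to 8j + 7) and a signed int32 result; the function names are hypothetical.

```python
from typing import List


def pack_ue8m0_sketch(exponents: List[int]) -> List[int]:
    """Pack groups of four UE8M0 bytes into signed int32 words (little-endian)."""
    if len(exponents) % 4:
        raise ValueError("expected a multiple of 4 scales")
    words = []
    for i in range(0, len(exponents), 4):
        word = 0
        for j, e in enumerate(exponents[i:i + 4]):
            if not 0 <= e <= 0xFF:
                raise ValueError("UE8M0 scales are single bytes")
            word |= e << (8 * j)  # scale j occupies bits 8*j .. 8*j + 7
        if word >= 1 << 31:
            word -= 1 << 32  # reinterpret the 32 bits as a signed int32
        words.append(word)
    return words


def ue8m0_to_float(e: int) -> float:
    """Decode one UE8M0 byte: 2 ** (e - 127), with 0xFF reserved for NaN."""
    return float("nan") if e == 0xFF else 2.0 ** (e - 127)
```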
set_mk_alignment_for_contiguous_layout
set_mk_alignment_for_contiguous_layout(value: int) -> None
Set DeepGEMM's BLOCK_M cap for grouped contiguous GEMMs.
The DeepGEMM heuristic constrains BLOCK_M ≤ this value when picking a kernel layout. Use this together with compute_aligned_M_and_alignment's per-call alignment so that the workspace's per-expert padding matches the kernel's BLOCK_M. On a mismatch, the scheduler reads the wrong expert_id from m_indices at an m_block_idx * BLOCK_M stride and indexes the B-weights tensor out of bounds, which manifests as an illegal memory access (IMA) under CUDA-graph replay.
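The failure mode can be reproduced on a toy m_indices array. The sketch pads each expert's rows to align, marks padding rows with -1 (an assumed convention), and checks whether every BLOCK_M-row tile contains rows from a single expert, which is what reading one expert_id per tile requires.

```python
from typing import List


def build_m_indices(tokens_per_expert: List[int], align: int) -> List[int]:
    """Row -> expert id for a contiguous grouped layout; -1 marks padding rows."""
    m_indices: List[int] = []
    for expert_id, count in enumerate(tokens_per_expert):
        padded = -(-count // align) * align  # pad each expert's segment to `align`
        m_indices += [expert_id] * count + [-1] * (padded - count)
    return m_indices


def blocks_are_homogeneous(m_indices: List[int], block_m: int) -> bool:
    """True if every BLOCK_M-row tile holds rows of at most one expert."""
    for start in range(0, len(m_indices), block_m):
        experts = {e for e in m_indices[start:start + block_m] if e >= 0}
        if len(experts) > 1:
            return False
    return True
```

With tokens_per_expert = [3, 5] and align = 4, tiles of BLOCK_M = 4 each belong to one expert, while BLOCK_M = 8 makes the first tile span experts 0 and 1: the padding/BLOCK_M mismatch described above.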
should_auto_disable_deep_gemm
Check if DeepGemm should be auto-disabled for this model on Blackwell.
Returns True if the model is known to have accuracy degradation with DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
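The check combines two conditions: the GPU is SM100 or newer, and the model is on a known-regression list. In this sketch the list contents, the use of a model-name string as the key, and the function name are all hypothetical.

```python
from typing import Tuple

# Hypothetical deny-list; the real set of affected models is not given here.
_E8M0_ACCURACY_ISSUES = frozenset({"example/model-with-e8m0-regression"})


def should_auto_disable_sketch(model_name: str, capability: Tuple[int, int]) -> bool:
    """Disable DeepGEMM only on SM100+ and only for models known to regress."""
    return capability >= (10, 0) and model_name in _E8M0_ACCURACY_ISSUES
```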
tf32_hc_prenorm_gemm
tf32_hc_prenorm_gemm(
x: Tensor,
fn: Tensor,
out: Tensor,
sqrsum: Tensor,
num_split: int,
) -> Tensor
Perform the following computation:
out = x.float() @ fn.T
sqrsum = x.float().square().sum(-1)
See the caller function for shape requirements.
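A plain-Python reference for the two outputs, useful as a test oracle. num_split is not modeled because its semantics are not described on this page. The tf32 prefix suggests the kernel uses TF32 tensor-core math, so its results may differ slightly from this exact computation.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def prenorm_gemm_reference(x: Matrix, fn: Matrix) -> Tuple[Matrix, List[float]]:
    """Return (x @ fn.T, per-row sum of squares of x)."""
    out = [[sum(a * b for a, b in zip(row, f)) for f in fn] for row in x]
    sqrsum = [sum(a * a for a in row) for row in x]
    return out, sqrsum
```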