Skip to content

vllm.model_executor.layers.fused_moe.deep_gemm_utils

Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00ba5c0cef8bff6/lightllm/common/fused_moe/deepep_scatter_gather.py and updated to fit vllm needs and terminology.

compute_aligned_M

compute_aligned_M(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_tokens_meta: ExpertTokensMetadata | None,
) -> int

Return M_sum only (backward-compat wrapper).

Equivalent to the first return value of `compute_aligned_M_and_alignment`. Existing downstream callers and the warmup path, which only need to size a workspace, use this. Call sites that need the actual per-expert alignment (to wrap GEMMs in `mk_alignment_scope`) should use `compute_aligned_M_and_alignment` instead.

Source code in vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
def compute_aligned_M(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> int:
    """Backward-compatible wrapper that returns only the padded row count.

    This forwards every argument unchanged to
    :func:`compute_aligned_M_and_alignment` and discards the second element
    (the alignment actually used). Use it where only the workspace size
    matters, e.g. the warmup path. Callers that must know the real per-expert
    padding (for example to enter ``mk_alignment_scope``) should call
    :func:`compute_aligned_M_and_alignment` directly.

    Returns:
        ``M_sum``, the first element of the tuple returned by
        :func:`compute_aligned_M_and_alignment`.
    """
    aligned = compute_aligned_M_and_alignment(
        M,
        num_topk,
        local_num_experts,
        alignment,
        expert_tokens_meta,
    )
    # Only the padded size is needed here; the effective alignment is dropped.
    return aligned[0]

compute_aligned_M_and_alignment

compute_aligned_M_and_alignment(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_tokens_meta: ExpertTokensMetadata | None,
) -> tuple[int, int]

Return (M_sum, alignment_used).

`alignment_used` may be smaller than the caller-supplied `alignment` on SM100/SM120 when DeepGEMM can JIT a smaller BLOCK_M for the per-call `expected_m`. Callers that index by block size (e.g. `M_sum // block_m`) or assert workspace alignment must use the returned `alignment_used`, not their original `alignment` argument.

Prefer this over the int-returning `compute_aligned_M` when the GEMM call site needs to wrap itself in `mk_alignment_scope` or otherwise reason about the actual per-expert padding.

Source code in vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
def compute_aligned_M_and_alignment(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[int, int]:
    """Return (M_sum, alignment_used).

    `alignment_used` may be smaller than the caller-supplied `alignment` on
    SM100/SM120 when DeepGEMM can JIT a smaller BLOCK_M for the per-call
    expected_m. Callers that index by block size (e.g. ``M_sum // block_m``)
    or assert workspace alignment must use the returned `alignment_used`,
    not their original `alignment` argument.

    Prefer this over the int-returning :func:`compute_aligned_M` when the
    GEMM call site needs to wrap itself in ``mk_alignment_scope`` or
    otherwise reason about the actual per-expert padding.

    Args:
        M: Number of input rows (tokens) for this call.
        num_topk: Experts selected per row; ``M * num_topk`` is the total
            number of routed rows.
        local_num_experts: Number of experts handled locally.
        alignment: Caller-requested per-expert padding. It is only ever
            shrunk, never grown.
        expert_tokens_meta: Optional metadata. When it carries
            ``expert_num_tokens_cpu``, the exact per-expert counts are
            rounded up and summed, and `alignment` is returned unchanged.

    Returns:
        ``(M_sum, alignment_used)``. ``M_sum`` is the padded total row count
        and ``alignment_used`` is the per-expert alignment it was computed
        with.
    """
    if (expert_tokens_meta is not None) and (
        expert_tokens_meta.expert_num_tokens_cpu is not None
    ):
        return (
            expert_num_tokens_round_up_and_sum(
                expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment
            ),
            alignment,
        )

    # expert_num_tokens not on cpu. Cap padding by min(M*num_topk,
    # local_num_experts) — at batch=1 decode only `num_topk` experts can be
    # active, so the worst-case `local_num_experts*(align-1)` is too loose.
    # Also shrink `alignment` to DeepGEMM's per-call theoretical BLOCK_M on
    # SM100/SM120 when smaller.
    expected_m = M * num_topk
    per_call_align = _theoretical_per_call_alignment(
        expected_m, local_num_experts
    )
    # Only ever shrink the alignment; a falsy / larger answer keeps the
    # caller's value.
    if per_call_align and per_call_align <= alignment:
        alignment = per_call_align

    max_active_experts = min(expected_m, local_num_experts)
    M_sum = expected_m + max_active_experts * (alignment - 1)
    M_sum = round_up(M_sum, alignment)
    return M_sum, alignment


def _theoretical_per_call_alignment(
    expected_m: int, num_groups: int
) -> int | None:
    """Best-effort query of DeepGEMM's theoretical BLOCK_M for one call.

    Returns whatever ``get_theoretical_mk_alignment_for_contiguous_layout``
    returns. If importing or calling that helper raises, returns ``None``, so
    the caller keeps its own alignment. The failure is logged at debug level
    rather than silently discarded.
    """
    try:
        from vllm.utils.deep_gemm import (
            get_theoretical_mk_alignment_for_contiguous_layout,
        )

        # num_groups=local_num_experts so the helper recovers per-expert em;
        # omitting it over-picks BLOCK_M on SM120 (heuristic assumes em is
        # already per-expert).
        return get_theoretical_mk_alignment_for_contiguous_layout(
            expected_m=expected_m,
            num_groups=num_groups,
        )
    except Exception:
        # Deliberately best-effort: falling back to the caller's (larger)
        # alignment is always safe. Still leave a trace, so that a
        # broken/mismatched helper does not silently disable the optimization.
        import logging

        logging.getLogger(__name__).debug(
            "DeepGEMM per-call alignment unavailable; keeping caller alignment.",
            exc_info=True,
        )
        return None