Skip to content

vllm.model_executor.warmup.kernel_warmup

Warmup kernels used during model execution. This is useful specifically for JIT'ed kernels as we don't want JIT'ing to happen during model execution.

_deepseek_v4_request_prep_warmup

_deepseek_v4_request_prep_warmup(worker: Worker) -> None

Pre-JIT the slot-mapping kernel before the first real request.

Source code in vllm/model_executor/warmup/kernel_warmup.py
@torch.inference_mode()
def _deepseek_v4_request_prep_warmup(worker: "Worker") -> None:
    """Run the DeepSeek V4 request-preparation warmup ahead of real traffic.

    Does nothing unless `VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP` is set,
    the runner is not a pooling model, the runner reports a DeepSeek V4
    sparse-MLA backend, and the current platform is CUDA-alike.

    Args:
        worker: Worker whose `model_runner` is warmed up.
    """
    if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP:
        return

    runner = worker.model_runner
    # Short-circuit order matches the individual guards: pooling check,
    # then backend check, then platform check.
    not_applicable = (
        runner.is_pooling_model
        or not _has_deepseek_v4_sparse_mla_backend(runner)
        or not current_platform.is_cuda_alike()
    )
    if not_applicable:
        return

    logger.info("Warming up DeepSeek V4 request preparation kernels.")
    _deepseek_v4_slot_mapping_warmup(runner)
    # Wait for the warmup launches to finish before returning.
    torch.accelerator.synchronize()

_deepseek_v4_slot_mapping_warmup

_deepseek_v4_slot_mapping_warmup(
    runner: GPUModelRunner,
) -> None

Pre-JIT _compute_slot_mapping_kernel across decode-shaped sizes.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None:
    """Pre-JIT `_compute_slot_mapping_kernel` across decode-shaped sizes.

    For every entry of `_DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS` (clamped
    against the runner's `max_num_tokens`, defaulting to 1), a one-request
    batch is staged -- `query_start_loc = [0, num_tokens]` and positions
    `0..num_tokens-1` -- and the block table's `commit_block_table(1)` and
    `compute_slot_mapping(1, ...)` are invoked on it.

    The first two `query_start_loc` entries and the touched slice of
    `runner.positions` are restored afterwards, even if a warmup call raises.

    Args:
        runner: Model runner owning the block table and staging buffers.
    """
    token_cap = getattr(runner, "max_num_tokens", 1)
    block_table = runner.input_batch.block_table
    uses_runner_qsl = hasattr(runner, "query_start_loc")

    # Keep copies of the runner buffers we overwrite so the warmup leaves
    # no trace in the runner's state.
    qsl_backup = None
    if uses_runner_qsl:
        qsl_backup = (
            runner.query_start_loc.np[:2].copy(),
            runner.query_start_loc.gpu[:2].clone(),
        )

    try:
        for requested_tokens in _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS:
            num_tokens = _clamp_warmup_tokens(requested_tokens, token_cap)
            if num_tokens <= 0:
                continue

            arange_positions = torch.arange(
                num_tokens, dtype=torch.int64, device=runner.device
            )

            if uses_runner_qsl:
                # Reuse the runner's own buffer: one request spanning
                # [0, num_tokens).
                qsl_buffer = runner.query_start_loc
                qsl_buffer.np[0] = 0
                qsl_buffer.np[1] = num_tokens
                qsl_buffer.copy_to_gpu(2)
                query_start_loc = qsl_buffer.gpu[:2]
            else:
                query_start_loc = torch.tensor(
                    [0, num_tokens], dtype=torch.int32, device=runner.device
                )

            positions = arange_positions
            positions_backup = None
            if hasattr(runner, "positions"):
                # Write into the runner's positions buffer, remembering the
                # previous contents of the slice we touch.
                positions = runner.positions[:num_tokens]
                positions_backup = positions.clone()
                positions.copy_(arange_positions)

            try:
                block_table.commit_block_table(1)
                block_table.compute_slot_mapping(1, query_start_loc, positions)
            finally:
                if positions_backup is not None:
                    runner.positions[:num_tokens].copy_(positions_backup)
    finally:
        if qsl_backup is not None:
            qsl_np_backup, qsl_gpu_backup = qsl_backup
            runner.query_start_loc.np[:2] = qsl_np_backup
            runner.query_start_loc.gpu[:2].copy_(qsl_gpu_backup)

_deepseek_v4_sparse_mla_attention_warmup

_deepseek_v4_sparse_mla_attention_warmup(
    worker: Worker,
) -> None

Warm sparse-MLA attention shapes via _dummy_run.

Three shapes: mixed prefill+decode, single max-chunk prefill, and a second-chunk prefill (prior context) — the last covers _build_prefill_chunk_metadata_kernel's alt-shape specialization.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None:
    """Warm sparse-MLA attention shapes via `_dummy_run`.

    Up to three dummy runs are issued: a mixed prefill+decode batch, a single
    prefill, and a second-chunk prefill with `profile_seq_lens` set to twice
    the prefill size (prior context) -- the last covers
    `_build_prefill_chunk_metadata_kernel`'s alt-shape specialization.

    The mixed batch is skipped here when
    `_deepseek_v4_sparse_mla_decode_autotune` reports that it already ran it.

    Args:
        worker: Worker providing the model runner and scheduler config.
    """
    if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP:
        return

    runner = worker.model_runner
    if runner.is_pooling_model or not _has_deepseek_v4_sparse_mla_backend(runner):
        return

    token_budget = worker.scheduler_config.max_num_batched_tokens
    mixed_tokens = _clamp_warmup_tokens(
        _DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS, token_budget
    )
    prefill_tokens = _clamp_warmup_tokens(
        _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS, token_budget
    )
    run_mixed = mixed_tokens > 0
    run_prefill = prefill_tokens > 0
    if not (run_mixed or run_prefill):
        return

    logger.info(
        "Warming up DeepSeek V4 sparse MLA attention "
        "for mixed tokens=%s and prefill tokens=%s.",
        mixed_tokens,
        prefill_tokens,
    )

    # Keyword arguments shared by every dummy run below.
    shared_kwargs = {
        "skip_eplb": True,
        "is_profile": True,
        "force_attention": True,
    }

    # The autotune helper returns True when it already executed the mixed
    # batch, in which case it must not be run a second time.
    if run_mixed and not _deepseek_v4_sparse_mla_decode_autotune(
        worker, mixed_tokens
    ):
        runner._dummy_run(
            num_tokens=mixed_tokens,
            create_mixed_batch=True,
            **shared_kwargs,
        )

    if not run_prefill:
        return

    runner._dummy_run(
        num_tokens=prefill_tokens,
        create_single_prefill=True,
        **shared_kwargs,
    )
    # Second-chunk shape: indexer sees prior context, hits the alt
    # specialization of `_build_prefill_chunk_metadata_kernel`.
    runner._dummy_run(
        num_tokens=prefill_tokens,
        create_single_prefill=True,
        profile_seq_lens=prefill_tokens * 2,
        **shared_kwargs,
    )

_deepseek_v4_sparse_mla_decode_autotune

_deepseek_v4_sparse_mla_decode_autotune(
    worker: Worker, num_tokens: int
) -> bool

Autotune FlashInfer's DSv4 SM120 sparse-MLA decode path.

Returns True when this function consumed the mixed attention warmup shape.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def _deepseek_v4_sparse_mla_decode_autotune(
    worker: "Worker",
    num_tokens: int,
) -> bool:
    """Autotune FlashInfer's DSv4 SM120 sparse-MLA decode path.

    Rank 0 of the world group runs the mixed-batch dummy run inside
    FlashInfer's `sparse_mla_sm120_decode_dsv4_autotune` context; all other
    ranks run the same dummy batch without it. The cache file produced on
    rank 0 is then broadcast as bytes, written out by each non-leader rank
    with `local_rank == 0`, and loaded into FlashInfer's `AutoTuner` on every
    rank.

    Args:
        worker: Worker providing the config and the model runner.
        num_tokens: Token count for the mixed-batch dummy run.

    Returns:
        True when this function consumed the mixed attention warmup shape
        (i.e. the dummy run was executed here); False when autotuning is
        disabled, unsupported on this device, or not exposed by the
        installed FlashInfer build.
    """
    autotune_flag = worker.vllm_config.kernel_config.enable_flashinfer_autotune
    if autotune_flag is not True:
        return False
    if not has_flashinfer():
        return False
    if not current_platform.is_device_capability_family(120):
        return False

    try:
        from flashinfer import sparse_mla_sm120_decode_dsv4_autotune
        from flashinfer.autotuner import AutoTuner
    except ImportError:
        logger.warning(
            "Skipping DeepSeek V4 sparse MLA decode autotune because this "
            "FlashInfer build does not expose sparse_mla_sm120_decode_dsv4_autotune."
        )
        return False

    from vllm.distributed.parallel_state import get_world_group

    runner = worker.model_runner
    group = get_world_group()
    on_leader = group.rank_in_group == 0
    cache_file = _resolve_flashinfer_autotune_file(runner)

    run_kwargs = {
        "num_tokens": num_tokens,
        "skip_eplb": True,
        "is_profile": True,
        "force_attention": True,
        "create_mixed_batch": True,
    }

    if on_leader:
        logger.info(
            "Autotuning DeepSeek V4 SM120 sparse MLA decode with FlashInfer "
            "cache file: %s",
            cache_file,
        )

    # Every rank executes the dummy batch; only the leader does so inside
    # the autotune context.
    with torch.inference_mode():
        if not on_leader:
            runner._dummy_run(**run_kwargs)
        else:
            with sparse_mla_sm120_decode_dsv4_autotune(cache_path=str(cache_file)):
                runner._dummy_run(**run_kwargs)

    # Ship the leader's cache file contents (or None) to every rank.
    payload: bytes | None = None
    if on_leader and cache_file.exists():
        with open(cache_file, "rb") as src:
            payload = src.read()
    payload = group.broadcast_object(payload, src=0)

    if payload is not None:
        # One writer per non-leader local group materialises the file.
        if not on_leader and group.local_rank == 0:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_file, "wb") as dst:
                dst.write(payload)
        # Nobody loads the file until the writers have finished.
        group.barrier()

        AutoTuner.get().load_configs(str(cache_file))
        logger.info(
            "DeepSeek V4 sparse MLA decode autotune cache loaded on rank %d from %s.",
            group.rank_in_group,
            cache_file,
        )
        return True

    logger.warning(
        "No DeepSeek V4 sparse MLA decode autotune cache entries found. "
        "Falling back to FlashInfer's default tactic heuristic."
    )
    group.barrier()
    return True

flashinfer_autotune

flashinfer_autotune(runner: GPUModelRunner) -> None

Autotune FlashInfer operations. FlashInfer has many implementations for the same operation; autotuning runs benchmarks for each implementation and stores the results. The results are cached transparently and future calls to FlashInfer will use the best implementation. Without autotuning, FlashInfer will rely on heuristics, which may be significantly slower.

Tuning is performed only on rank 0. The resulting cache is broadcast to every rank so all ranks dispatch the same kernel tactic.

Source code in vllm/model_executor/warmup/kernel_warmup.py
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.

    FlashInfer has many implementations for the same operation;
    autotuning runs benchmarks for each implementation and stores
    the results. The results are cached transparently and
    future calls to FlashInfer will use the best implementation.
    Without autotuning, FlashInfer will rely on heuristics, which may
    be significantly slower.

    When `_FLASHINFER_USE_PERSISTENT_CACHE` is falsy, every rank runs the
    dummy batch under `fi_utils.autotune()` and the ranks then meet at a
    barrier.

    Otherwise tuning is performed only on rank 0. The resulting cache file
    is broadcast to every rank (and written by each non-leader rank with
    `local_rank == 0`) so all ranks dispatch the same kernel tactic.

    Args:
        runner: Model runner used to execute the dummy batch and to resolve
            the autotune cache file.
    """
    import vllm.utils.flashinfer as fi_utils
    from vllm.distributed.parallel_state import get_world_group

    if not _FLASHINFER_USE_PERSISTENT_CACHE:
        with torch.inference_mode(), fi_utils.autotune():
            runner._dummy_run(
                num_tokens=runner.scheduler_config.max_num_batched_tokens,
                skip_eplb=True,
                is_profile=True,
            )
        get_world_group().barrier()
        return

    world = get_world_group()
    is_leader = world.rank_in_group == 0

    cache_path = _resolve_flashinfer_autotune_file(runner)
    if is_leader:
        logger.info("Using FlashInfer autotune cache file: %s", cache_path)

    # We skip EPLB here since we don't want to record dummy metrics.
    # When autotuning with number of tokens m, flashinfer will autotune
    # operations for all number of tokens up to m, so we only need to
    # run with the max number of tokens.
    dummy_run_kwargs = dict(
        num_tokens=runner.scheduler_config.max_num_batched_tokens,
        skip_eplb=True,
        is_profile=True,
    )

    # Every rank runs the dummy batch; only the leader does so in tune mode.
    with torch.inference_mode():
        if is_leader:
            with fi_utils.autotune(tune_mode=True, cache=str(cache_path)):
                runner._dummy_run(**dummy_run_kwargs)
        else:
            runner._dummy_run(**dummy_run_kwargs)

    # Broadcast autotune cache from rank 0 to all other ranks so every
    # rank loads the same set of chosen tactics.
    tune_results: bytes | None = None
    if is_leader and cache_path.exists():
        with open(cache_path, "rb") as f:
            tune_results = f.read()

    tune_results = world.broadcast_object(tune_results, src=0)

    if tune_results is None:
        # All ranks receive the same broadcast value, so they all take this
        # branch together.
        logger.warning(
            "No FlashInfer autotune cache entries found. "
            "Falling back to default tactics."
        )
        return

    if not is_leader and world.local_rank == 0:
        # The cache directory may not exist yet where the leader did not
        # write; create it before writing, as the DeepSeek V4 decode
        # autotune path does.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_path, "wb") as f:
            f.write(tune_results)
    # Do not load the file until the writers have finished.
    world.barrier()
    from flashinfer.autotuner import AutoTuner

    AutoTuner.get().load_configs(str(cache_path))
    logger.info(
        "FlashInfer autotune cache loaded on rank %d from %s.",
        world.rank_in_group,
        cache_path,
    )