class FlashInferMLASparseSM120Impl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]):
    """SM120 FlashInfer sparse-MLA implementation.

    Construction validates that the layer configuration is one this backend
    can serve (decoder attention, ``fp8_ds_mla`` KV cache, no ALiBi / sliding
    window / logits soft cap, an indexer present, FlashInfer sparse-MLA
    support available). ``forward_mqa`` then runs the FlashInfer MLA decode
    wrapper over the top-k KV entries selected by the indexer.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        indexer: "Indexer | None" = None,
        **mla_args,
    ) -> None:
        """Validate the configuration and record the layer parameters.

        Args:
            num_heads: Number of attention heads; also the head dimension of
                the output buffer allocated in ``forward_mqa``.
            head_size: Stored on the instance as-is.
            scale: Softmax scale; stored as a Python ``float`` and passed to
                the kernel as ``bmm1_scale``.
            num_kv_heads: Stored on the instance as-is.
            alibi_slopes: Must be falsy; this backend rejects it otherwise.
            sliding_window: Must be falsy; this backend rejects it otherwise.
            kv_cache_dtype: Must be exactly ``"fp8_ds_mla"``.
            logits_soft_cap: Must be falsy; this backend rejects it otherwise.
            attn_type: Must equal ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: Accepted but not used here.
            indexer: Object exposing ``topk_indices_buffer``; required.
            **mla_args: Must contain ``kv_lora_rank``, ``qk_nope_head_dim``
                and ``qk_rope_head_dim``.

        Raises:
            NotImplementedError: For an unsupported feature, attention type
                or KV cache dtype.
            KeyError: If a required entry is missing from ``mla_args``.
            AssertionError: If ``indexer`` or its top-k buffer is ``None``.
            RuntimeError: If FlashInfer's sparse MLA decode API is missing.
        """
        # Any truthy value here means the caller asked for a feature this
        # backend does not implement.
        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "FLASHINFER_MLA_SPARSE_SM120 does not support alibi_slopes / "
                "sliding_window / logits_soft_cap"
            )
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "FLASHINFER_MLA_SPARSE_SM120 only supports decoder self-attention"
            )

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        if kv_cache_dtype != "fp8_ds_mla":
            raise NotImplementedError(
                "FLASHINFER_MLA_SPARSE_SM120 requires the packed fp8_ds_mla "
                f"KV cache layout; got kv_cache_dtype={kv_cache_dtype!r}."
            )

        # MLA geometry is delivered through the catch-all keyword arguments.
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"]
        self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"]

        from vllm.config import get_current_vllm_config

        # The KV scale format is chosen from the model type; the model type
        # stays None when no model config is attached or the text config has
        # no such attribute.
        model_config = get_current_vllm_config().model_config
        model_type = (
            None
            if model_config is None
            else getattr(model_config.hf_text_config, "model_type", None)
        )
        self.kv_scale_format = _kv_scale_format_for_model(model_type)

        assert indexer is not None, (
            "FLASHINFER_MLA_SPARSE_SM120 requires a sparse-MLA indexer "
            "(model with index_topk in its config)."
        )
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

        from vllm.utils.flashinfer import has_flashinfer_sparse_mla_sm120

        if not has_flashinfer_sparse_mla_sm120():
            raise RuntimeError(
                "FLASHINFER_MLA_SPARSE_SM120 requires FlashInfer's "
                "sparse MLA decode API."
            )
        assert self.topk_indices_buffer is not None

        self.supports_quant_query_input = False
        # Filled in on the first forward_mqa call.
        self._workspace_buffer: torch.Tensor | None = None

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashInferMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run sparse MLA decode over the indexer-selected top-k KV entries.

        Args:
            q: Query tensor, or a tuple of tensors that is concatenated along
                the last dimension. ``q.shape[0]`` is used as the token count.
            kv_c_and_k_pe_cache: KV cache tensor; reinterpreted as ``uint8``
                before being handed to the kernel.
            attn_metadata: Supplies ``req_id_per_token``, ``block_table``,
                ``block_size`` and ``topk_tokens``.
            layer: Accepted but not used here.

        Returns:
            A pair ``(attention_output, None)``. The output buffer handed to
            the kernel has shape ``(num_tokens, num_heads, kv_lora_rank)``.
        """
        query = torch.cat(q, dim=-1) if isinstance(q, tuple) else q
        token_count = query.shape[0]

        assert self.topk_indices_buffer is not None
        logical_topk = self.topk_indices_buffer[:token_count]
        # NOTE(review): presumably maps per-request top-k indices to global
        # cache slots through the block table -- confirm against the helper.
        physical_topk = cast(
            torch.Tensor,
            triton_convert_req_index_to_global_index(
                attn_metadata.req_id_per_token[:token_count],
                attn_metadata.block_table,
                logical_topk,
                BLOCK_SIZE=attn_metadata.block_size,
                NUM_TOPK_TOKENS=logical_topk.shape[1],
            ),
        )

        out_buffer = query.new_empty(
            (token_count, self.num_heads, self.kv_lora_rank),
            dtype=query.dtype,
        )

        # Obtain the workspace once and reuse it on later calls.
        workspace = self._workspace_buffer
        if workspace is None:
            workspace = _get_workspace_buffer(query.device)
            self._workspace_buffer = workspace

        from vllm.utils.flashinfer import (
            flashinfer_trtllm_batch_decode_with_kv_cache_mla,
        )

        # Every tensor operand gets a singleton dimension inserted at index 1
        # before the call; the result has it removed again below.
        kernel_query = query.unsqueeze(1)
        kernel_cache = kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(1)
        kernel_tables = physical_topk.unsqueeze(1)
        topk = attn_metadata.topk_tokens

        result = flashinfer_trtllm_batch_decode_with_kv_cache_mla(
            query=kernel_query,
            kv_cache=kernel_cache,
            workspace_buffer=workspace,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=kernel_tables,
            seq_lens=None,
            max_seq_len=topk,
            out=out_buffer.unsqueeze(1),
            bmm1_scale=self.scale,
            bmm2_scale=1.0,
            sparse_mla_top_k=topk,
            kv_scale_format=self.kv_scale_format,
        )
        return result.squeeze(1), None