[SM100] Enable fp8 compute for prefill MLA (#30746)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Author: Pavani Majety
Date: 2025-12-22 11:15:57 -08:00 (committed by GitHub)
Parent: 7b926e8901
Commit: b10f41c894
3 changed files with 117 additions and 18 deletions

File 1 of 3

@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import MLAAttentionSpec
BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend(
backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config,
device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec(
backend_kv_cache_spec = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
@@ -748,6 +748,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
backend_output = run_attention_backend(

File 2 of 3

@@ -325,7 +325,6 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]

File 3 of 3

@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
max_query_len: int
chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
@dataclass
@@ -539,6 +540,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata
)
self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
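
A minimal sketch of the query-dtype selection added above, pulled out of __init__ for readability (same logic as the diff; the standalone function and the import lines are our own framing):

    import torch
    from vllm.platforms import current_platform  # assumed import path for current_platform

    def select_prefill_q_dtype(kv_cache_spec, model_dtype: torch.dtype) -> torch.dtype:
        # Use the platform fp8 dtype (e.g. torch.float8_e4m3fn on NVIDIA) when the
        # KV cache is stored in an fp8 format such as "fp8" or "fp8_e4m3";
        # otherwise keep the model's compute dtype for the FlashInfer prefill plan.
        if kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str:
            return current_platform.fp8_dtype()
        return model_dtype
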
@@ -681,7 +687,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()
# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
@@ -694,7 +699,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
)
# Prepare context prefills
@@ -713,7 +718,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
)
prefill.prefill_main = self._fi_prefill_main
@@ -970,6 +975,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
)
if self._use_cudnn_prefill:
@@ -1370,8 +1376,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out
def _run_prefill_new_tokens_fa(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@@ -1386,11 +1399,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_fi(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run(
q=q,
k=k,
@@ -1403,10 +1428,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
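
The fp8 handling above reduces to one pattern, repeated for the FlashInfer paths and the TRT-LLM ragged paths further down: cast q/k/v to the platform fp8 dtype right before calling a kernel that was planned with q_data_type set to that same dtype. A condensed sketch (the helper name is ours, not part of the diff):

    import torch
    from vllm.platforms import current_platform

    def maybe_cast_qkv_to_fp8(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                              fp8_attention: bool):
        # When fp8 prefill is enabled, convert the activations to the platform
        # fp8 dtype (e.g. torch.float8_e4m3fn) so they match the planned
        # q_data_type; otherwise pass them through unchanged.
        if not fp8_attention:
            return q, k, v
        fp8_dtype = current_platform.fp8_dtype()
        return q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)

Backends that cannot run fp8 prefill (cuDNN, and the FlashAttention context-chunk path) instead assert that the flag is off, as the hunks below show.
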
def _run_prefill_new_tokens_cudnn(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@@ -1428,9 +1461,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output
def _run_prefill_context_chunk_fa(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@@ -1445,10 +1488,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_context_chunk_fi(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
@@ -1460,12 +1515,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@@ -1485,13 +1548,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek(
query=q,
key=k,
@@ -1518,8 +1595,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
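
Across this file the per-backend prefill runners now share a uniform shape: new-token runners take (prefill, q, k, v, return_softmax_lse, fp8_attention) and context-chunk runners, like the TRT-LLM ragged one below, take (prefill, chunk_idx, q, k, v, fp8_attention). A purely illustrative sketch of that shared signature as Protocols (these classes are not defined in the diff):

    from typing import Protocol

    import torch

    class RunPrefillNewTokens(Protocol):
        def __call__(self, prefill, q: torch.Tensor, k: torch.Tensor,
                     v: torch.Tensor, return_softmax_lse: bool,
                     fp8_attention: bool): ...

    class RunPrefillContextChunk(Protocol):
        def __call__(self, prefill, chunk_idx: int, q: torch.Tensor,
                     k: torch.Tensor, v: torch.Tensor,
                     fp8_attention: bool): ...
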
def _run_prefill_context_chunk_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
@@ -1535,6 +1619,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
self._workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
key=k,
@@ -1687,6 +1778,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
fp8_attention: bool,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
@@ -1725,6 +1817,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@@ -1753,6 +1846,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
fp8_attention: bool,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
@@ -1829,6 +1923,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@@ -1859,6 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
output: torch.Tensor,
fp8_attention: bool = False,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
@@ -1878,6 +1974,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k,
v=v,
return_softmax_lse=has_context,
fp8_attention=fp8_attention,
)
if has_context:
@@ -1890,11 +1987,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
k_scale=None,
dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
)
)
else:
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
)
# unpad if necessary
@@ -2015,6 +2113,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
)
if has_decode:
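
Summary of the fp8 prefill support this commit establishes, compiled from the casts and assertions above (how the caller derives fp8_attention, for example from the configured kv-cache dtype, is outside this diff):

    # A reading aid, not an API introduced by the commit.
    FP8_PREFILL_COMPUTE = {
        "flashinfer": True,           # q/k/v cast to current_platform.fp8_dtype() before .run()
        "trtllm_ragged": True,        # cast before trtllm_ragged_attention_deepseek
        "cudnn": False,               # both cuDNN paths assert fp8_attention is False
        "flash_attn_context": False,  # context-chunk path asserts fp8_attention is False
    }

The FlashAttention new-token path only threads the flag through in the hunks shown here.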