Revert "[SM100] Enable fp8 compute for prefill MLA (#30746)" (#31197)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Author: Pavani Majety
Date: 2025-12-22 18:15:33 -08:00 (committed by GitHub)
parent 612d5ffdab
commit 3e10262356
3 changed files with 18 additions and 116 deletions

--- File 1 of 3 ---

@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
-from vllm.v1.kv_cache_interface import MLAAttentionSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
     AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
 
 def run_attention_backend(
     backend: AttentionBackendEnum,
-    kv_cache_spec: MLAAttentionSpec,
+    kv_cache_spec: FullAttentionSpec,
     layer_names: list[str],
     vllm_config,
     device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
         kv_cache = kv_cache_per_block_size[block_size]
 
         # Create kv_cache_spec with the correct block_size for this backend
-        backend_kv_cache_spec = MLAAttentionSpec(
+        backend_kv_cache_spec = FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                vllm_config.parallel_config
@@ -748,7 +748,6 @@ def test_backend_correctness(
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
-            cache_dtype_str=vllm_config.cache_config.cache_dtype,
         )
 
         backend_output = run_attention_backend(
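For reference, after this revert the test builds a FullAttentionSpec with only the fields shown in the hunks above. A minimal sketch, assuming exactly those keyword arguments and using hypothetical stand-in values for the vllm_config lookups:

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec

# Hypothetical stand-in values; the test derives these from vllm_config.
spec = FullAttentionSpec(
    block_size=16,          # per-backend block size, as in the loop above
    num_kv_heads=1,         # MLA caches a single latent KV head
    head_size=576,          # e.g. 512 latent dims + 64 rope dims for DeepSeek MLA
    dtype=torch.bfloat16,   # model dtype
    sliding_window=None,    # get_sliding_window() result
)
# Note: cache_dtype_str is gone; the spec no longer carries the KV cache dtype.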

--- File 2 of 3 ---

@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
+        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         do_finalize=True,
     )[0]
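The hunk above re-adds an explicit tile_tokens_dim=None keyword to the FlashInfer TRT-LLM fp4 MoE call. A minimal sketch of the call-site shape; the stub below is a hypothetical stand-in for the kernel, and only the keyword names come from the hunk:

# Hypothetical stub standing in for the FlashInfer TRT-LLM fp4 MoE kernel.
def trtllm_fp4_moe_stub(*, local_expert_offset, local_num_experts,
                        routed_scaling_factor, tile_tokens_dim,
                        routing_method_type, do_finalize):
    return [None]  # the real kernel returns outputs; the caller indexes [0]

out = trtllm_fp4_moe_stub(
    local_expert_offset=0,      # layer.ep_rank * layer.local_num_experts
    local_num_experts=8,        # hypothetical value
    routed_scaling_factor=None,
    tile_tokens_dim=None,       # keyword restored by this revert
    routing_method_type=0,      # hypothetical enum value
    do_finalize=True,
)[0]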

--- File 3 of 3 ---

@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             metadata_cls if metadata_cls is not None else MLACommonMetadata
         )
         self.kv_cache_spec = kv_cache_spec
-        self.q_data_type = (
-            current_platform.fp8_dtype()
-            if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
-            else vllm_config.model_config.dtype
-        )
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
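The five removed lines above are the core of the revert on the metadata-builder side: the fp8 path planned prefill with an fp8 query dtype whenever the KV cache dtype string mentioned fp8. A minimal runnable sketch of that (now reverted) selection, with torch.float8_e4m3fn standing in for current_platform.fp8_dtype():

import torch

def select_q_data_type(cache_dtype_str, model_dtype):
    # Reverted logic: use the platform fp8 dtype when the KV cache is fp8,
    # otherwise fall back to the model dtype. After this revert the builder
    # always plans with the model dtype (see the q_data_type hunks below).
    fp8_dtype = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
    if cache_dtype_str is not None and "fp8" in cache_dtype_str:
        return fp8_dtype
    return model_dtype

assert select_q_data_type("fp8_e4m3", torch.bfloat16) == torch.float8_e4m3fn
assert select_q_data_type("auto", torch.bfloat16) == torch.bfloat16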
@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             # For main run, qo_indptr == kv_indptr
             kv_indptr = qo_indptr.clone()
 
+            # Prepare main prefill
             self._fi_prefill_main.plan(
                 qo_indptr=qo_indptr,
@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 sm_scale=self._global_hyperparameters.sm_scale,
                 window_left=self._global_hyperparameters.window_left,
                 logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-                q_data_type=self.q_data_type,
+                q_data_type=self.model_config.dtype,
             )
 
             # Prepare context prefills
@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     sm_scale=self._global_hyperparameters.sm_scale,
                     window_left=self._global_hyperparameters.window_left,
                     logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-                    q_data_type=self.q_data_type,
+                    q_data_type=self.model_config.dtype,
                 )
 
             prefill.prefill_main = self._fi_prefill_main
@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
-                q_data_type=self.q_data_type,
             )
 
             if self._use_cudnn_prefill:
@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out
 
     def _run_prefill_new_tokens_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
-        if fp8_attention:
-            logger.debug_once("Running Flashinfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
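The if fp8_attention: block removed above (and its twins in the FlashInfer context-chunk and TRT-LLM ragged paths below) downcast q/k/v just before the kernel call. A minimal runnable sketch of that (now reverted) pattern, again with torch.float8_e4m3fn standing in for current_platform.fp8_dtype():

import torch

def maybe_cast_qkv_to_fp8(q, k, v, fp8_attention):
    # Reverted behavior: with fp8 attention enabled, q/k/v were cast to the
    # platform fp8 dtype right before the prefill kernel ran.
    if fp8_attention:
        fp8_dtype = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
        q, k, v = q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)
    return q, k, v

q = torch.randn(2, 4, dtype=torch.bfloat16)
q, k, v = maybe_cast_qkv_to_fp8(q, q.clone(), q.clone(), fp8_attention=True)
assert q.dtype == torch.float8_e4m3fn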
@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_new_tokens_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running Cudnn prefill new tokens", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         output, lse = cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return output
 
     def _run_prefill_context_chunk_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
         assert prefill.chunked_context is not None
-        assert fp8_attention is False, (
-            "FlashAttention prefill does not support fp8 attention"
-        )
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_context_chunk_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
-        if fp8_attention:
-            logger.debug_once("Running FlashInfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
+
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def _run_prefill_context_chunk_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running Cudnn prefill context chunk", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.chunked_context is not None
         assert prefill.chunked_context.seq_lens[chunk_idx] is not None
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         return cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
         """TRT-LLM ragged attention for new tokens (causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
 
         assert prefill.query_seq_lens is not None
         assert prefill.workspace_buffer is not None
-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
+
         ret = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_context_chunk_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
         """TRT-LLM ragged attention for context chunks (non-causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
         prefill.workspace_buffer.fill_(0)
-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
-
         attn_out, lse = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-        fp8_attention: bool,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         dcp_world_size: int,
-        fp8_attention: bool,
     ):
         assert k_scale is None, "DCP not support scaled kvcache now."
         assert attn_metadata.prefill is not None
@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         output: torch.Tensor,
-        fp8_attention: bool = False,
     ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             k=k,
             v=v,
             return_softmax_lse=has_context,
-            fp8_attention=fp8_attention,
         )
 
         if has_context:
@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                     attn_metadata,
                     k_scale=None,
                     dcp_world_size=self.dcp_world_size,
-                    fp8_attention=fp8_attention,
                 )
             )
         else:
             context_output, context_lse = self._compute_prefill_context(
-                q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale
             )
 
         # unpad if necessary
@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 attn_metadata,
                 layer._k_scale,
                 output=output[num_decode_tokens:],
-                fp8_attention=fp8_attention,
             )
 
         if has_decode: