diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 6a62440d95417..783e02ce89bdb 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
-from vllm.v1.kv_cache_interface import MLAAttentionSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
     AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
 
 def run_attention_backend(
     backend: AttentionBackendEnum,
-    kv_cache_spec: MLAAttentionSpec,
+    kv_cache_spec: FullAttentionSpec,
     layer_names: list[str],
     vllm_config,
     device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
         kv_cache = kv_cache_per_block_size[block_size]
 
         # Create kv_cache_spec with the correct block_size for this backend
-        backend_kv_cache_spec = MLAAttentionSpec(
+        backend_kv_cache_spec = FullAttentionSpec(
             block_size=block_size,
             num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                 vllm_config.parallel_config
@@ -748,7 +748,6 @@ def test_backend_correctness(
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
-            cache_dtype_str=vllm_config.cache_config.cache_dtype,
         )
 
         backend_output = run_attention_backend(
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 16598d567848f..1d410316d6299 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
+        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         do_finalize=True,
     )[0]
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 619c1c2794daa..e9ec96835f277 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             metadata_cls if metadata_cls is not None else MLACommonMetadata
         )
         self.kv_cache_spec = kv_cache_spec
-        self.q_data_type = (
-            current_platform.fp8_dtype()
-            if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
-            else vllm_config.model_config.dtype
-        )
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         # For main run, qo_indptr == kv_indptr
         kv_indptr = qo_indptr.clone()
+
         # Prepare main prefill
         self._fi_prefill_main.plan(
             qo_indptr=qo_indptr,
@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             sm_scale=self._global_hyperparameters.sm_scale,
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-            q_data_type=self.q_data_type,
+            q_data_type=self.model_config.dtype,
         )
 
         # Prepare context prefills
@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 sm_scale=self._global_hyperparameters.sm_scale,
                 window_left=self._global_hyperparameters.window_left,
                 logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-                q_data_type=self.q_data_type,
+                q_data_type=self.model_config.dtype,
             )
 
         prefill.prefill_main = self._fi_prefill_main
@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
-                q_data_type=self.q_data_type,
             )
 
             if self._use_cudnn_prefill:
@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out
 
     def _run_prefill_new_tokens_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
-        if fp8_attention:
-            logger.debug_once("Running Flashinfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_new_tokens_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running Cudnn prefill new tokens", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         output, lse = cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return output
 
     def _run_prefill_context_chunk_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
         assert prefill.chunked_context is not None
-        assert fp8_attention is False, (
-            "FlashAttention prefill does not support fp8 attention"
-        )
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_context_chunk_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
-        if fp8_attention:
-            logger.debug_once("Running FlashInfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
+
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def _run_prefill_context_chunk_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running Cudnn prefill context chunk", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         return cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
         """TRT-LLM ragged attention for new tokens (causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
 
         assert prefill.query_seq_lens is not None
         assert prefill.workspace_buffer is not None
 
-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
-
         ret = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_context_chunk_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
         """TRT-LLM ragged attention for context chunks (non-causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
 
@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
         prefill.workspace_buffer.fill_(0)
 
-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
-
         attn_out, lse = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-        fp8_attention: bool,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         dcp_world_size: int,
-        fp8_attention: bool,
    ):
         assert k_scale is None, "DCP not support scaled kvcache now."
         assert attn_metadata.prefill is not None
@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         output: torch.Tensor,
-        fp8_attention: bool = False,
     ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             k=k,
             v=v,
             return_softmax_lse=has_context,
-            fp8_attention=fp8_attention,
         )
 
         if has_context:
@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                     attn_metadata,
                     k_scale=None,
                     dcp_world_size=self.dcp_world_size,
-                    fp8_attention=fp8_attention,
                 )
             )
         else:
             context_output, context_lse = self._compute_prefill_context(
-                q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale
             )
 
             # unpad if necessary
@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 attn_metadata,
                 layer._k_scale,
                 output=output[num_decode_tokens:],
-                fp8_attention=fp8_attention,
             )
 
         if has_decode: