diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 783e02ce89bdb..6a62440d95417 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.kv_cache_interface import MLAAttentionSpec
 
 BACKENDS_TO_TEST = [
     AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
 
 def run_attention_backend(
     backend: AttentionBackendEnum,
-    kv_cache_spec: FullAttentionSpec,
+    kv_cache_spec: MLAAttentionSpec,
     layer_names: list[str],
    vllm_config,
    device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
         kv_cache = kv_cache_per_block_size[block_size]
 
         # Create kv_cache_spec with the correct block_size for this backend
-        backend_kv_cache_spec = FullAttentionSpec(
+        backend_kv_cache_spec = MLAAttentionSpec(
             block_size=block_size,
             num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                 vllm_config.parallel_config
@@ -748,6 +748,7 @@ def test_backend_correctness(
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
         )
 
         backend_output = run_attention_backend(
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 1d410316d6299..16598d567848f 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -325,7 +325,6 @@ def flashinfer_trtllm_fp4_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
-        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         do_finalize=True,
     )[0]
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index fea482493635f..0ff951213040a 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
     max_query_len: int
     chunked_context: ChunkedContextMetadata | None = None
     query_seq_lens: torch.Tensor | None = None
+    q_data_type: torch.dtype | None = None
 
 
 @dataclass
@@ -539,6 +540,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             metadata_cls if metadata_cls is not None else MLACommonMetadata
         )
         self.kv_cache_spec = kv_cache_spec
+        self.q_data_type = (
+            current_platform.fp8_dtype()
+            if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
+            else vllm_config.model_config.dtype
+        )
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -681,7 +687,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         # For main run, qo_indptr == kv_indptr
         kv_indptr = qo_indptr.clone()
-
         # Prepare main prefill
         self._fi_prefill_main.plan(
             qo_indptr=qo_indptr,
@@ -694,7 +699,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             sm_scale=self._global_hyperparameters.sm_scale,
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-            q_data_type=self.model_config.dtype,
+            q_data_type=self.q_data_type,
         )
 
         # Prepare context prefills
@@ -713,7 +718,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 sm_scale=self._global_hyperparameters.sm_scale,
                 window_left=self._global_hyperparameters.window_left,
                 logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-                q_data_type=self.model_config.dtype,
+                q_data_type=self.q_data_type,
             )
 
         prefill.prefill_main = self._fi_prefill_main
@@ -970,6 +975,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
+                q_data_type=self.q_data_type,
             )
 
             if self._use_cudnn_prefill:
@@ -1370,8 +1376,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out
 
     def _run_prefill_new_tokens_fa(
-        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+        self,
+        prefill: MLACommonPrefillMetadata,
+        q,
+        k,
+        v,
+        return_softmax_lse,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1386,11 +1399,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_fi(
-        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+        self,
+        prefill: MLACommonPrefillMetadata,
+        q,
+        k,
+        v,
+        return_softmax_lse,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
-
+        if fp8_attention:
+            logger.debug_once("Running Flashinfer prefill in FP8")
+            fp8_dtype = current_platform.fp8_dtype()
+            q = q.to(fp8_dtype)
+            k = k.to(fp8_dtype)
+            v = v.to(fp8_dtype)
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1403,10 +1428,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_new_tokens_cudnn(
-        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+        self,
+        prefill: MLACommonPrefillMetadata,
+        q,
+        k,
+        v,
+        return_softmax_lse,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running Cudnn prefill new tokens", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.query_seq_lens is not None
+        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         output, lse = cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1428,9 +1461,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return output
 
     def _run_prefill_context_chunk_fa(
-        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+        self,
+        prefill: MLACommonPrefillMetadata,
+        chunk_idx: int,
+        q,
+        k,
+        v,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
         assert prefill.chunked_context is not None
+        assert fp8_attention is False, (
+            "FlashAttention prefill does not support fp8 attention"
+        )
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1445,10 +1488,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_context_chunk_fi(
-        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+        self,
+        prefill: MLACommonPrefillMetadata,
+        chunk_idx: int,
+        q,
+        k,
+        v,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
-
+        if fp8_attention:
+            logger.debug_once("Running FlashInfer prefill in FP8")
+            fp8_dtype = current_platform.fp8_dtype()
+            q = q.to(fp8_dtype)
+            k = k.to(fp8_dtype)
+            v = v.to(fp8_dtype)
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
@@ -1460,12 +1515,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def _run_prefill_context_chunk_cudnn(
-        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+        self,
+        prefill: MLACommonPrefillMetadata,
+        chunk_idx: int,
+        q,
+        k,
+        v,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running Cudnn prefill context chunk", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.chunked_context is not None
         assert prefill.chunked_context.seq_lens[chunk_idx] is not None
         assert prefill.query_seq_lens is not None
+        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         return cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1485,13 +1548,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
     def _run_prefill_new_tokens_trtllm_ragged(
-        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+        self,
+        prefill: MLACommonPrefillMetadata,
+        q,
+        k,
+        v,
+        return_softmax_lse,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
         """TRT-LLM ragged attention for new tokens (causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
 
         assert prefill.query_seq_lens is not None
 
+        if fp8_attention:
+            logger.debug_once("Running TRT-LLM ragged prefill in FP8")
+            fp8_dtype = current_platform.fp8_dtype()
+            q = q.to(fp8_dtype)
+            k = k.to(fp8_dtype)
+            v = v.to(fp8_dtype)
+
         ret = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1518,8 +1595,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return ret
 
     def _run_prefill_context_chunk_trtllm_ragged(
-        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+        self,
+        prefill: MLACommonPrefillMetadata,
+        chunk_idx: int,
+        q,
+        k,
+        v,
+        fp8_attention: bool,
     ):
+        logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
         """TRT-LLM ragged attention for context chunks (non-causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek
 
@@ -1535,6 +1619,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
         self._workspace_buffer.fill_(0)
 
+        if fp8_attention:
+            logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
+            fp8_dtype = current_platform.fp8_dtype()
+            q = q.to(fp8_dtype)
+            k = k.to(fp8_dtype)
+            v = v.to(fp8_dtype)
+
         attn_out, lse = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1687,6 +1778,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
+        fp8_attention: bool,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1725,6 +1817,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
+                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1753,6 +1846,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         dcp_world_size: int,
+        fp8_attention: bool,
     ):
         assert k_scale is None, "DCP not support scaled kvcache now."
         assert attn_metadata.prefill is not None
@@ -1829,6 +1923,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 q=q,
                 k=k,
                 v=v,
+                fp8_attention=fp8_attention,
             )
 
             if output is None:
@@ -1859,6 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         output: torch.Tensor,
+        fp8_attention: bool = False,
     ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
@@ -1878,6 +1974,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             k=k,
             v=v,
             return_softmax_lse=has_context,
+            fp8_attention=fp8_attention,
         )
 
         if has_context:
@@ -1890,11 +1987,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                         attn_metadata,
                         k_scale=None,
                         dcp_world_size=self.dcp_world_size,
+                        fp8_attention=fp8_attention,
                     )
                 )
             else:
                 context_output, context_lse = self._compute_prefill_context(
-                    q, kv_c_and_k_pe_cache, attn_metadata, k_scale
+                    q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
                 )
 
             # unpad if necessary
@@ -2015,6 +2113,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 attn_metadata,
                 layer._k_scale,
                 output=output[num_decode_tokens:],
+                fp8_attention=fp8_attention,
             )
 
         if has_decode:
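Reviewer note (not part of the patch): a minimal standalone sketch of the dtype plumbing this diff introduces. The metadata builder picks an FP8 query dtype when the KV cache dtype string contains "fp8", and the FlashInfer / TRT-LLM ragged prefill paths cast q/k/v to that dtype when fp8_attention is set. torch.float8_e4m3fn is assumed here as the platform FP8 dtype; in vLLM it is resolved via current_platform.fp8_dtype(). The function names below are illustrative, not vLLM APIs.

# Sketch only: mirrors the q_data_type selection and the FP8 cast, assuming
# torch.float8_e4m3fn as the platform FP8 dtype.
import torch


def select_q_data_type(cache_dtype_str: str, model_dtype: torch.dtype) -> torch.dtype:
    # Query dtype follows the KV cache: FP8 cache -> FP8 queries, else model dtype.
    return torch.float8_e4m3fn if "fp8" in cache_dtype_str else model_dtype


def maybe_cast_qkv_to_fp8(q, k, v, fp8_attention: bool):
    # Cast q/k/v to FP8 before handing them to an FP8-capable prefill kernel.
    if not fp8_attention:
        return q, k, v
    fp8_dtype = torch.float8_e4m3fn
    return q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)


if __name__ == "__main__":
    q = torch.randn(4, 16, 192, dtype=torch.bfloat16)
    k = torch.randn(4, 16, 192, dtype=torch.bfloat16)
    v = torch.randn(4, 16, 128, dtype=torch.bfloat16)
    assert select_q_data_type("fp8_e4m3", torch.bfloat16) == torch.float8_e4m3fn
    q8, k8, v8 = maybe_cast_qkv_to_fp8(q, k, v, fp8_attention=True)
    print(q8.dtype, k8.dtype, v8.dtype)  # torch.float8_e4m3fn for all three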