From 0f67d4d962872767ac1fca8e98d1bb679aae762a Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Fri, 24 Oct 2025 10:24:08 -0700
Subject: [PATCH] [Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)

Signed-off-by: Ming Yang
---
 vllm/envs.py                             |   6 ++
 vllm/v1/attention/backends/mla/common.py | 109 ++++++++++++++++++++++-
 2 files changed, 113 insertions(+), 2 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 45ce15d5ffb7..e6cef075528f 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -183,6 +183,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
@@ -1250,6 +1251,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL": lambda: bool(
         int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
     ),
+    # Controls whether to use TRT-LLM ragged DeepSeek prefill
+    "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool(
+        int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0"))
+    ),
     # If set to 1/True, use the TRTLLM attention backend in flashinfer.
     # If set to 0/False, use the default attention backend in flashinfer.
     # If not set, auto-detect the attention backend in flashinfer.
@@ -1481,6 +1486,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
+        "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
         "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
         "VLLM_ROCM_USE_AITER",
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 51a9032f4269..b920fd929e85 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -371,6 +371,7 @@ class MLACommonPrefillMetadata:
     query_start_loc: torch.Tensor
     max_query_len: int
     chunked_context: ChunkedContextMetadata | None = None
+    query_seq_lens: torch.Tensor | None = None
 
 
 @dataclass
@@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
     class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
         seq_lens: torch.Tensor
 
-    query_seq_lens: torch.Tensor | None = None
     cudnn_workspace: torch.Tensor | None = None
 
 
@@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
         not envs.VLLM_DISABLE_FLASHINFER_PREFILL
         and flashinfer_available
         and not envs.VLLM_USE_CUDNN_PREFILL
+        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
         and current_platform.is_device_capability(100)
     )
 
@@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
     )
 
 
+def use_trtllm_ragged_deepseek_prefill() -> bool:
+    """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    return (
+        flashinfer_available
+        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and current_platform.is_device_capability(100)
+    )
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
@@ -593,6 +603,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
+        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
         self.prefill_metadata_cls = (
             FlashInferPrefillMetadata
             if self._use_fi_prefill
@@ -613,6 +624,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
             )
 
+        if self._use_trtllm_ragged_prefill:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+            )
+
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
                 CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
@@ -934,6 +950,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 )
                 prefill_metadata.cudnn_workspace = self.cudnn_workspace
 
+            if self._use_trtllm_ragged_prefill:
+                prefill_metadata.query_seq_lens = (
+                    prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
+                )
+
         decode_metadata = None
         if num_decodes > 0:
             decode_metadata = self._build_decode(
@@ -1230,6 +1251,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_trtllm_ragged_deepseek_prefill():
+            logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
+            self._run_prefill_context_chunk = (
+                self._run_prefill_context_chunk_trtllm_ragged
+            )
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
+            self._pad_v = False
         elif use_cudnn_prefill():
             logger.debug_once("Using CUDNN prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
@@ -1326,6 +1354,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1334,7 +1363,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
         if isinstance(ret, tuple):
-            # Convert from (q_len, num_heads) to (num_heads, q_len)
             return ret[0], ret[1].transpose(0, 1).contiguous()
         return ret
 
@@ -1384,12 +1412,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
+
        attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
             v=v,
             return_lse=True,
         )
+
         # Convert from (q_len, num_heads) to (num_heads, q_len)
         return attn_out, lse.transpose(0, 1).contiguous()
 
@@ -1418,6 +1448,81 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             is_cuda_graph_compatible=True,
         )
 
+    def _run_prefill_new_tokens_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+    ):
+        """TRT-LLM ragged attention for new tokens (causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.query_seq_lens is not None
+
+        ret = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.query_seq_lens,
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.max_query_len,
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.query_seq_lens.shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.query_start_loc,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=return_softmax_lse,
+        )
+
+        if isinstance(ret, tuple):
+            # Convert from (q_len, num_heads) to (num_heads, q_len)
+            return ret[0], ret[1].transpose(0, 1).contiguous()
+        return ret
+
+    def _run_prefill_context_chunk_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+    ):
+        """TRT-LLM ragged attention for context chunks (non-causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+
+        out = torch.zeros(
+            q.shape[0],
+            q.shape[1],
+            v.shape[2],
+            device=q.device,
+            dtype=q.dtype,
+        )
+        self._workspace_buffer.fill_(0)
+
+        attn_out, lse = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            enable_pdl=False,
+            is_causal=False,
+            return_lse=True,
+            out=out,
+        )
+
+        # Convert from (q_len, num_heads) to (num_heads, q_len)
+        return attn_out, lse.transpose(0, 1).contiguous()
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
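
Usage note (illustrative sketch, not part of the patch): the new prefill path is opt-in via the
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL environment variable and, per
use_trtllm_ragged_deepseek_prefill(), only takes effect when FlashInfer is available and the device
reports compute capability 10.0. A minimal way to enable it, assuming a DeepSeek-style MLA
checkpoint is being served (the model name below is only an example):

    # Toggle the TRT-LLM ragged DeepSeek prefill backend before vLLM is imported,
    # mirroring how VLLM_USE_CUDNN_PREFILL is enabled.
    import os
    os.environ["VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # example MLA model
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)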