[Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)

Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
Ming Yang 2025-10-24 10:24:08 -07:00 committed by GitHub
parent 7e1d697b56
commit 0f67d4d962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 2 deletions

View File

@ -183,6 +183,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
@ -1250,6 +1251,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
),
# Controls whether to use TRT-LLM ragged DeepSeek prefill
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0"))
),
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
# If set to 0/False, use the default attention backend in flashinfer.
# If not set, auto-detect the attention backend in flashinfer.
@ -1481,6 +1486,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER",

View File

@ -371,6 +371,7 @@ class MLACommonPrefillMetadata:
query_start_loc: torch.Tensor
max_query_len: int
chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None
@dataclass
@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
seq_lens: torch.Tensor
query_seq_lens: torch.Tensor | None = None
cudnn_workspace: torch.Tensor | None = None
@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
not envs.VLLM_DISABLE_FLASHINFER_PREFILL
and flashinfer_available
and not envs.VLLM_USE_CUDNN_PREFILL
and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and current_platform.is_device_capability(100)
)
@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
)
def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
return (
flashinfer_available
and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and current_platform.is_device_capability(100)
)
# Currently 394MB, this can be tuned based on GEMM sizes used.
# Chosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
@ -593,6 +603,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self._use_cudnn_prefill = use_cudnn_prefill()
self._use_fi_prefill = use_flashinfer_prefill()
self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
self.prefill_metadata_cls = (
FlashInferPrefillMetadata
if self._use_fi_prefill
@ -613,6 +624,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
)
if self._use_trtllm_ragged_prefill:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
)
if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
@ -934,6 +950,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
prefill_metadata.cudnn_workspace = self.cudnn_workspace
if self._use_trtllm_ragged_prefill:
prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
)
decode_metadata = None
if num_decodes > 0:
decode_metadata = self._build_decode(
@ -1230,6 +1251,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
elif use_trtllm_ragged_deepseek_prefill():
logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
self._run_prefill_context_chunk = (
self._run_prefill_context_chunk_trtllm_ragged
)
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
self._pad_v = False
elif use_cudnn_prefill():
logger.debug_once("Using CUDNN prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
@ -1326,6 +1354,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
):
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
ret = prefill.prefill_main.run(
q=q,
k=k,
@ -1334,7 +1363,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if isinstance(ret, tuple):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return ret[0], ret[1].transpose(0, 1).contiguous()
return ret
@ -1384,12 +1412,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
assert isinstance(prefill, FlashInferPrefillMetadata)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
v=v,
return_lse=True,
)
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()
@ -1418,6 +1448,81 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
is_cuda_graph_compatible=True,
)
def _run_prefill_new_tokens_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
"""TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
ret = trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len,
bmm1_scale=self.scale,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=prefill.query_seq_lens.shape[0],
window_left=-1,
cum_seq_lens_q=prefill.query_start_loc,
cum_seq_lens_kv=prefill.query_start_loc,
enable_pdl=False,
is_causal=True,
return_lse=return_softmax_lse,
)
if isinstance(ret, tuple):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return ret[0], ret[1].transpose(0, 1).contiguous()
return ret
def _run_prefill_context_chunk_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
out = torch.zeros(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=q.dtype,
)
self._workspace_buffer.fill_(0)
attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
bmm1_scale=self.scale,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
window_left=-1,
cum_seq_lens_q=prefill.query_start_loc,
cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
enable_pdl=False,
is_causal=False,
return_lse=True,
out=out,
)
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")