mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)
Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
parent
7e1d697b56
commit
0f67d4d962
@ -183,6 +183,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
|
||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
|
||||
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
||||
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
|
||||
VLLM_LOOPBACK_IP: str = ""
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
@ -1250,6 +1251,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
|
||||
),
|
||||
# Controls whether to use TRT-LLM ragged DeepSeek prefill
|
||||
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0"))
|
||||
),
|
||||
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
|
||||
# If set to 0/False, use the default attention backend in flashinfer.
|
||||
# If not set, auto-detect the attention backend in flashinfer.
|
||||
@ -1481,6 +1486,7 @@ def compute_hash() -> str:
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
|
||||
"VLLM_USE_CUDNN_PREFILL",
|
||||
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
|
||||
"VLLM_USE_TRTLLM_ATTENTION",
|
||||
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
|
||||
"VLLM_ROCM_USE_AITER",
|
||||
|
||||
@ -371,6 +371,7 @@ class MLACommonPrefillMetadata:
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
chunked_context: ChunkedContextMetadata | None = None
|
||||
query_seq_lens: torch.Tensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
|
||||
class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
query_seq_lens: torch.Tensor | None = None
|
||||
cudnn_workspace: torch.Tensor | None = None
|
||||
|
||||
|
||||
@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
|
||||
not envs.VLLM_DISABLE_FLASHINFER_PREFILL
|
||||
and flashinfer_available
|
||||
and not envs.VLLM_USE_CUDNN_PREFILL
|
||||
and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
|
||||
and current_platform.is_device_capability(100)
|
||||
)
|
||||
|
||||
@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def use_trtllm_ragged_deepseek_prefill() -> bool:
|
||||
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
|
||||
return (
|
||||
flashinfer_available
|
||||
and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
|
||||
and current_platform.is_device_capability(100)
|
||||
)
|
||||
|
||||
|
||||
# Currently 394MB, this can be tuned based on GEMM sizes used.
|
||||
# Chosen to be the same as sglang:
|
||||
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
|
||||
@ -593,6 +603,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
|
||||
self.prefill_metadata_cls = (
|
||||
FlashInferPrefillMetadata
|
||||
if self._use_fi_prefill
|
||||
@ -613,6 +624,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
|
||||
)
|
||||
|
||||
if self._use_trtllm_ragged_prefill:
|
||||
self._workspace_buffer = torch.empty(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
self.cudnn_workspace = torch.empty(
|
||||
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
|
||||
@ -934,6 +950,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
)
|
||||
prefill_metadata.cudnn_workspace = self.cudnn_workspace
|
||||
|
||||
if self._use_trtllm_ragged_prefill:
|
||||
prefill_metadata.query_seq_lens = (
|
||||
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
@ -1230,6 +1251,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
||||
self._pad_v = False
|
||||
elif use_trtllm_ragged_deepseek_prefill():
|
||||
logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
|
||||
self._run_prefill_context_chunk = (
|
||||
self._run_prefill_context_chunk_trtllm_ragged
|
||||
)
|
||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
|
||||
self._pad_v = False
|
||||
elif use_cudnn_prefill():
|
||||
logger.debug_once("Using CUDNN prefill for MLA")
|
||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
|
||||
@ -1326,6 +1354,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
assert prefill.prefill_main is not None
|
||||
|
||||
ret = prefill.prefill_main.run(
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1334,7 +1363,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||
return ret
|
||||
|
||||
@ -1384,12 +1412,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
|
||||
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
|
||||
@ -1418,6 +1448,81 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
is_cuda_graph_compatible=True,
|
||||
)
|
||||
|
||||
def _run_prefill_new_tokens_trtllm_ragged(
|
||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
||||
):
|
||||
"""TRT-LLM ragged attention for new tokens (causal)."""
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
assert prefill.query_seq_lens is not None
|
||||
|
||||
ret = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
seq_lens=prefill.query_seq_lens,
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.max_query_len,
|
||||
bmm1_scale=self.scale,
|
||||
bmm2_scale=1.0,
|
||||
o_sf_scale=1.0,
|
||||
batch_size=prefill.query_seq_lens.shape[0],
|
||||
window_left=-1,
|
||||
cum_seq_lens_q=prefill.query_start_loc,
|
||||
cum_seq_lens_kv=prefill.query_start_loc,
|
||||
enable_pdl=False,
|
||||
is_causal=True,
|
||||
return_lse=return_softmax_lse,
|
||||
)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||
return ret
|
||||
|
||||
def _run_prefill_context_chunk_trtllm_ragged(
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
"""TRT-LLM ragged attention for context chunks (non-causal)."""
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
assert prefill.chunked_context is not None
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
|
||||
out = torch.zeros(
|
||||
q.shape[0],
|
||||
q.shape[1],
|
||||
v.shape[2],
|
||||
device=q.device,
|
||||
dtype=q.dtype,
|
||||
)
|
||||
self._workspace_buffer.fill_(0)
|
||||
|
||||
attn_out, lse = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||
bmm1_scale=self.scale,
|
||||
bmm2_scale=1.0,
|
||||
o_sf_scale=1.0,
|
||||
batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
|
||||
window_left=-1,
|
||||
cum_seq_lens_q=prefill.query_start_loc,
|
||||
cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
|
||||
enable_pdl=False,
|
||||
is_causal=False,
|
||||
return_lse=True,
|
||||
out=out,
|
||||
)
|
||||
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user