Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:35:00 +08:00)

[Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)

Signed-off-by: Ming Yang <minos.future@gmail.com>

Parent: 7e1d697b56
Commit: 0f67d4d962
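In short, this commit wires an opt-in MLA prefill path through FlashInfer's trtllm_ragged_attention_deepseek kernel, gated by a new VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL environment variable and restricted to device capability 100. A minimal sketch of opting in, assuming a vLLM build that contains this commit (it only demonstrates the flag, not a full server launch):

# Minimal sketch, assuming a vLLM build that includes this commit: the new prefill
# path is off by default and enabled via an environment variable.
import os

os.environ["VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL"] = "1"  # default is "0" (disabled)

import vllm.envs as envs  # variables in vllm.envs are read lazily at access time

assert envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL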
@@ -183,6 +183,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
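Note that the entry in the if TYPE_CHECKING: block only declares the attribute and its type for static checkers; the runtime value comes from the environment_variables registry updated in the next hunk.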
@@ -1250,6 +1251,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL": lambda: bool(
         int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
     ),
+    # Controls whether to use TRT-LLM ragged DeepSeek prefill
+    "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool(
+        int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0"))
+    ),
     # If set to 1/True, use the TRTLLM attention backend in flashinfer.
     # If set to 0/False, use the default attention backend in flashinfer.
     # If not set, auto-detect the attention backend in flashinfer.
@@ -1481,6 +1486,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
+        "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
         "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
         "VLLM_ROCM_USE_AITER",
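Adding the flag to this factor list means toggling it changes the value returned by compute_hash(); as far as I can tell that hash feeds cache keys (e.g. for compiled artifacts), so runs with and without the new backend do not share those caches. A toy illustration of the idea only, not vLLM's actual implementation:

# Toy illustration: a digest over the environment "factors" changes when the new flag
# flips, so anything keyed on the digest is recomputed rather than reused.
import hashlib

def toy_env_hash(factors: list[str]) -> str:
    return hashlib.sha256("|".join(factors).encode()).hexdigest()[:12]

base = ["VLLM_USE_CUDNN_PREFILL=0"]
print(toy_env_hash(base + ["VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=0"]))
print(toy_env_hash(base + ["VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1"]))  # different digest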
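The hunks above touch vLLM's environment-variable module; the remaining hunks touch the common MLA attention backend, i.e. the module defining MLACommonPrefillMetadata, MLACommonMetadataBuilder, and MLACommonImpl.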
@@ -371,6 +371,7 @@ class MLACommonPrefillMetadata:
     query_start_loc: torch.Tensor
     max_query_len: int
     chunked_context: ChunkedContextMetadata | None = None
+    query_seq_lens: torch.Tensor | None = None
 
 
 @dataclass
@@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
     class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
         seq_lens: torch.Tensor
 
-    query_seq_lens: torch.Tensor | None = None
     cudnn_workspace: torch.Tensor | None = None
 
 
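Taken together with the hunk at -371 above, this moves query_seq_lens from CudnnPrefillMetadata up to the shared MLACommonPrefillMetadata, so the field is available to the new TRT-LLM ragged prefill path as well as the existing cuDNN path.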
@@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
         not envs.VLLM_DISABLE_FLASHINFER_PREFILL
         and flashinfer_available
         and not envs.VLLM_USE_CUDNN_PREFILL
+        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
         and current_platform.is_device_capability(100)
     )
 
@@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
     )
 
 
+def use_trtllm_ragged_deepseek_prefill() -> bool:
+    """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    return (
+        flashinfer_available
+        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and current_platform.is_device_capability(100)
+    )
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
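To make the selection order concrete: use_flashinfer_prefill() (hunk at -457 above) now returns False whenever the new flag is set, so in MLACommonImpl's elif chain (hunk at -1230 below) the ragged DeepSeek branch is taken when the flag is on and the FlashInfer/SM100 checks pass. A standalone sketch of that precedence, with simplified stand-in booleans rather than vLLM's real envs and platform objects (the cuDNN gate is reduced to a single flag here):

# Standalone sketch of the prefill-backend gating; all booleans are stand-ins for
# envs.* flags, FlashInfer availability, and the device-capability-100 check.
def pick_prefill_backend(
    flashinfer_available: bool,
    is_sm100: bool,
    trtllm_ragged_flag: bool,
    cudnn_flag: bool,
    disable_flashinfer_prefill: bool = False,
) -> str:
    use_fi = (
        not disable_flashinfer_prefill
        and flashinfer_available
        and not cudnn_flag
        and not trtllm_ragged_flag
        and is_sm100
    )
    use_trtllm_ragged = flashinfer_available and trtllm_ragged_flag and is_sm100
    if use_fi:
        return "flashinfer"
    elif use_trtllm_ragged:
        return "trtllm_ragged_deepseek"
    elif cudnn_flag:
        return "cudnn"
    return "default (e.g. FlashAttention/Triton path)"

print(pick_prefill_backend(True, True, trtllm_ragged_flag=True, cudnn_flag=False))
# -> trtllm_ragged_deepseek
print(pick_prefill_backend(True, True, trtllm_ragged_flag=False, cudnn_flag=False))
# -> flashinfer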
@@ -593,6 +603,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
+        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
         self.prefill_metadata_cls = (
             FlashInferPrefillMetadata
             if self._use_fi_prefill
@@ -613,6 +624,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
             )
 
+        if self._use_trtllm_ragged_prefill:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+            )
+
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
                 CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
@@ -934,6 +950,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 )
                 prefill_metadata.cudnn_workspace = self.cudnn_workspace
 
+            if self._use_trtllm_ragged_prefill:
+                prefill_metadata.query_seq_lens = (
+                    prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
+                )
+
         decode_metadata = None
         if num_decodes > 0:
             decode_metadata = self._build_decode(
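The query_seq_lens computed above is just the element-wise difference of adjacent entries in the cumulative query_start_loc offsets. A tiny illustration with made-up values:

# Illustration of query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]:
# start offsets [0, 3, 8, 12] describe 3 sequences of 3, 5 and 4 new query tokens
# packed back-to-back in one ragged batch.
import torch

prefill_query_start_loc = torch.tensor([0, 3, 8, 12], dtype=torch.int32)
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
print(query_seq_lens.tolist())  # [3, 5, 4]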
@@ -1230,6 +1251,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_trtllm_ragged_deepseek_prefill():
+            logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
+            self._run_prefill_context_chunk = (
+                self._run_prefill_context_chunk_trtllm_ragged
+            )
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
+            self._pad_v = False
         elif use_cudnn_prefill():
             logger.debug_once("Using CUDNN prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
@@ -1326,6 +1354,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1334,7 +1363,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
 
         if isinstance(ret, tuple):
-            # Convert from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
         return ret
 
@@ -1384,12 +1412,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
+
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
             v=v,
             return_lse=True,
         )
+
         # Convert from (q_len, num_heads) to (num_heads, q_len)
         return attn_out, lse.transpose(0, 1).contiguous()
 
@@ -1418,6 +1448,81 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             is_cuda_graph_compatible=True,
         )
 
+    def _run_prefill_new_tokens_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+    ):
+        """TRT-LLM ragged attention for new tokens (causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.query_seq_lens is not None
+
+        ret = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.query_seq_lens,
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.max_query_len,
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.query_seq_lens.shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.query_start_loc,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=return_softmax_lse,
+        )
+
+        if isinstance(ret, tuple):
+            # Convert from (q_len, num_heads) to (num_heads, q_len)
+            return ret[0], ret[1].transpose(0, 1).contiguous()
+        return ret
+
+    def _run_prefill_context_chunk_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+    ):
+        """TRT-LLM ragged attention for context chunks (non-causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+
+        out = torch.zeros(
+            q.shape[0],
+            q.shape[1],
+            v.shape[2],
+            device=q.device,
+            dtype=q.dtype,
+        )
+        self._workspace_buffer.fill_(0)
+
+        attn_out, lse = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            enable_pdl=False,
+            is_causal=False,
+            return_lse=True,
+            out=out,
+        )
+
+        # Convert from (q_len, num_heads) to (num_heads, q_len)
+        return attn_out, lse.transpose(0, 1).contiguous()
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
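A note on the two new methods: the new-token path runs causal attention over the packed query tokens themselves (cum_seq_lens_q and cum_seq_lens_kv are both query_start_loc), while the context-chunk path runs non-causal attention of the queries against a chunk of previously cached context. Both return a log-sum-exp (LSE) alongside the output so that per-chunk partial results can be combined, which is also why the LSE is transposed to the (num_heads, q_len) layout the caller expects. A generic, self-contained illustration of that LSE-based merge follows; it is not vLLM's internal merge routine, just the underlying math, and the output shapes are chosen only for the example:

# Generic log-sum-exp merge of two partial attention results computed over disjoint
# KV chunks; the LSE follows the (num_heads, q_len) layout the code above converts to.
import torch

def merge_partials(o1, lse1, o2, lse2):
    # o*: (num_heads, q_len, head_dim), lse*: (num_heads, q_len)
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse

num_heads, q_len, head_dim = 2, 4, 8
o1 = torch.randn(num_heads, q_len, head_dim)
o2 = torch.randn(num_heads, q_len, head_dim)
lse1 = torch.randn(num_heads, q_len)
lse2 = torch.randn(num_heads, q_len)
merged, merged_lse = merge_partials(o1, lse1, o2, lse2)
print(merged.shape, merged_lse.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4])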