[Nvidia] Integrate SM100 cudnn prefill API to MLA prefill (#20411)

Signed-off-by: Elfie Guo <elfieg@nvidia.com>
Co-authored-by: Elfie Guo <eflieg@nvidia.com>
Authored by Elfie Guo on 2025-07-15 17:56:45 -07:00; committed via GitHub
parent 10be209493
commit 30800b01c2
2 changed files with 113 additions and 5 deletions

vllm/envs.py (5 changes, Normal file → Executable file)

@@ -139,6 +139,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_LOOPBACK_IP: str = ""
@@ -962,6 +963,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
# Controls whether to use cuDNN prefill
"VLLM_USE_CUDNN_PREFILL":
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
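Usage note (not part of the diff): the new flag follows vLLM's usual int-as-bool environment-variable pattern, so any value that parses to a nonzero integer enables the cuDNN prefill path. A minimal sketch, assuming vLLM is installed and that vllm.envs resolves these lambdas lazily on attribute access:

import os

# Set before the flag is first read; "1" -> bool(int("1")) -> True.
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"

import vllm.envs as envs
assert envs.VLLM_USE_CUDNN_PREFILL is True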

vllm/v1/attention/backends/mla/common.py (113 changes, Normal file → Executable file)

@@ -194,6 +194,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
@@ -225,6 +226,8 @@ except ImportError:
try:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
from flashinfer.prefill import ( # noqa: F401
cudnn_batch_prefill_with_kv_cache)
flashinfer_available = True
except ImportError:
flashinfer_available = False
@@ -236,6 +239,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
CUDNN_WORKSPACE_SIZE = 12800
class MLACommonBackend(AttentionBackend):
@@ -294,6 +299,7 @@ class MLACommonPrefillMetadata:
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor
block_table: torch.Tensor
@@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
default_factory=list)
@dataclass
class CudnnPrefillMetadata(MLACommonPrefillMetadata):
class ChunkedContextMetadata(
MLACommonPrefillMetadata.ChunkedContextMetadata):
seq_lens: torch.Tensor
query_seq_lens: Optional[torch.Tensor] = None
cudnn_workspace: Optional[torch.Tensor] = None
@dataclass
class MLACommonDecodeMetadata:
block_table: torch.Tensor
@@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):
decode: Optional[D] = None
prefill: Optional[Union[MLACommonPrefillMetadata,
-FlashInferPrefillMetadata]] = None
+FlashInferPrefillMetadata,
+CudnnPrefillMetadata]] = None
def __post_init__(self):
if self.head_dim is not None:
@@ -362,13 +380,19 @@ M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool:
-if flashinfer_available:
+if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
# For Blackwell, default to FlashInfer prefill if it is available, since
# it is faster than FA2.
return current_platform.has_device_capability(100)
return False
def use_cudnn_prefill() -> bool:
if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
return current_platform.has_device_capability(100)
return False
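Read together with the dispatch in MLACommonImpl further down, these helpers give cuDNN prefill precedence over FlashInfer prefill on capability-10.0 (SM100) devices when the flag is set, and fall back to FlashAttention otherwise. A minimal sketch of that precedence; the function name and its boolean arguments are illustrative only, not part of vLLM:

def select_mla_prefill_backend(flashinfer_available: bool,
                               use_cudnn_env: bool,
                               has_sm100: bool) -> str:
    # cuDNN prefill still requires the flashinfer package, since
    # cudnn_batch_prefill_with_kv_cache is imported from flashinfer.prefill.
    # has_sm100: device compute capability 10.0 or newer.
    if flashinfer_available and has_sm100:
        return "cudnn" if use_cudnn_env else "flashinfer"
    return "flash_attn"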
# Currently 394 MB; this can be tuned based on the GEMM sizes used.
# Chosen to be the same as SGLang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
@@ -427,11 +451,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
dtype=model_config.dtype,
device=runner.device,
)
self.block_table = block_table
self._use_cudnn_prefill = use_cudnn_prefill()
self._use_fi_prefill = use_flashinfer_prefill()
-self.prefill_metadata_cls = FlashInferPrefillMetadata \
-if self._use_fi_prefill else MLACommonPrefillMetadata
+self.prefill_metadata_cls = (
+FlashInferPrefillMetadata
+if self._use_fi_prefill else CudnnPrefillMetadata
+if self._use_cudnn_prefill else MLACommonPrefillMetadata)
if self._use_fi_prefill:
self._workspace_buffer = torch.empty(
@@ -447,6 +475,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
dtype=torch.int8,
device=runner.device,
)
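For a rough sense of scale (illustrative arithmetic, not from the diff): the cuDNN workspace is CUDNN_WORKSPACE_SIZE = 12800 int8 elements per sequence, so the buffer above grows linearly with max_num_seqs:

CUDNN_WORKSPACE_SIZE = 12800  # int8 elements (bytes) per sequence

for max_num_seqs in (128, 256, 1024):  # example scheduler settings
    nbytes = CUDNN_WORKSPACE_SIZE * max_num_seqs
    print(max_num_seqs, nbytes / (1 << 20), "MiB")
# 128 -> 1.5625 MiB, 256 -> 3.125 MiB, 1024 -> 12.5 MiB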
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
@@ -692,15 +727,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata_cls = \
CudnnPrefillMetadata.ChunkedContextMetadata \
if self._use_cudnn_prefill else \
MLACommonPrefillMetadata.ChunkedContextMetadata
chunked_context_metadata = \
-MLACommonPrefillMetadata.ChunkedContextMetadata(
+chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
if self._use_cudnn_prefill:
chunked_context_metadata.seq_lens = chunk_seq_lens
assert max(chunked_context_metadata.max_seq_lens) <= \
self.chunked_prefill_workspace_size
@@ -711,6 +755,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
chunked_context=chunked_context_metadata,
)
if self._use_cudnn_prefill:
assert isinstance(prefill_metadata, CudnnPrefillMetadata)
prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
- prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace
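The query_seq_lens assignment above is simply a first difference of the cumulative query_start_loc offsets, giving per-request query lengths that the cuDNN calls later reshape to [B, 1, 1, 1]. A small self-contained example with made-up values:

import torch

# Three prefill requests with 3, 4, and 5 new tokens respectively.
prefill_query_start_loc = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
print(query_seq_lens)                          # tensor([3, 4, 5], dtype=torch.int32)
print(query_seq_lens.view(-1, 1, 1, 1).shape)  # torch.Size([3, 1, 1, 1])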
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
@@ -794,6 +844,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
elif use_cudnn_prefill():
logger.debug_once("Using CUDNN prefill for MLA")
self._run_prefill_context_chunk = \
self._run_prefill_context_chunk_cudnn
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
self._pad_v = False
else: # Use FlashAttention
logger.debug_once("Using FlashAttention prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
@@ -882,6 +938,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return_lse=return_softmax_lse,
)
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
q, k, v, return_softmax_lse):
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None
output, lse = cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
v_cache=v,
scale=self.scale,
workspace_buffer=prefill.cudnn_workspace,
max_token_per_sequence=prefill.max_query_len,
max_sequence_kv=prefill.max_query_len,
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
causal=True,
return_lse=True,  # cuDNN does not support return_lse=False for now
is_cuda_graph_compatible=True,
# True: the actual_seq_lens tensors are on the GPU; False: on the CPU.
)
if return_softmax_lse:
return output, lse
return output
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert prefill.chunked_context is not None
@@ -908,6 +987,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return_lse=True,
)
def _run_prefill_context_chunk_cudnn(self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
v_cache=v,
scale=self.scale,
workspace_buffer=prefill.cudnn_workspace,
max_token_per_sequence=prefill.max_query_len,
max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
view(-1, 1, 1, 1),
causal=False,
return_lse=True,
is_cuda_graph_compatible=True,
# True: the actual_seq_lens tensors are on the GPU; False: on the CPU.
)
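Both cuDNN paths return the softmax log-sum-exp along with the output because chunked prefill has to merge partial attention results, for the new tokens and for each context chunk, into a single answer. The merge is standard log-sum-exp reweighting; a minimal sketch of the idea (the helper name and shapes are illustrative, this is not the exact merge routine vLLM calls):

import torch

def merge_attn_partials(o1: torch.Tensor, lse1: torch.Tensor,
                        o2: torch.Tensor, lse2: torch.Tensor):
    # o*:   [num_tokens, num_heads, head_dim] partial attention outputs
    # lse*: [num_tokens, num_heads] log-sum-exp of the corresponding scores
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse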
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)