diff --git a/vllm/envs.py b/vllm/envs.py
old mode 100644
new mode 100755
index 37dd8146c060b..502978c768515
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -139,6 +139,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
+    VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_LOOPBACK_IP: str = ""
@@ -962,6 +963,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
     lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
 
+    # Controls whether or not to use cudnn prefill
+    "VLLM_USE_CUDNN_PREFILL":
+    lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
+
     # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
old mode 100644
new mode 100755
index 904b6081d9222..381a92a830932
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -194,6 +194,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
@@ -225,6 +226,8 @@ except ImportError:
 
 try:
     from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+    from flashinfer.prefill import (  # noqa: F401
+        cudnn_batch_prefill_with_kv_cache)
     flashinfer_available = True
 except ImportError:
     flashinfer_available = False
@@ -236,6 +239,8 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+CUDNN_WORKSPACE_SIZE = 12800
+
 
 class MLACommonBackend(AttentionBackend):
 
@@ -294,6 +299,7 @@ class MLACommonPrefillMetadata:
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
+        seq_lens: torch.Tensor
         workspace: torch.Tensor
 
     block_table: torch.Tensor
@@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
         default_factory=list)
 
 
+@dataclass
+class CudnnPrefillMetadata(MLACommonPrefillMetadata):
+
+    class ChunkedContextMetadata(
+            MLACommonPrefillMetadata.ChunkedContextMetadata):
+        seq_lens: torch.Tensor
+
+    query_seq_lens: Optional[torch.Tensor] = None
+    cudnn_workspace: Optional[torch.Tensor] = None
+
+
 @dataclass
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
@@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):
 
     decode: Optional[D] = None
     prefill: Optional[Union[MLACommonPrefillMetadata,
-                            FlashInferPrefillMetadata]] = None
+                            FlashInferPrefillMetadata,
+                            CudnnPrefillMetadata]] = None
 
     def __post_init__(self):
         if self.head_dim is not None:
@@ -362,13 +380,19 @@ M = TypeVar("M", bound=MLACommonMetadata)
 
 
 def use_flashinfer_prefill() -> bool:
-    if flashinfer_available:
+    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
         # For blackwell default to flashinfer prefill if its available since
         # its faster than FA2.
        return current_platform.has_device_capability(100)
    return False
 
 
+def use_cudnn_prefill() -> bool:
+    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
+        return current_platform.has_device_capability(100)
+    return False
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Choosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
@@ -427,11 +451,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 dtype=model_config.dtype,
                 device=runner.device,
             )
+
         self.block_table = block_table
 
+        self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
-        self.prefill_metadata_cls = FlashInferPrefillMetadata \
-            if self._use_fi_prefill else MLACommonPrefillMetadata
+        self.prefill_metadata_cls = (
+            FlashInferPrefillMetadata
+            if self._use_fi_prefill else CudnnPrefillMetadata
+            if self._use_cudnn_prefill else MLACommonPrefillMetadata)
 
         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
@@ -447,6 +475,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             self._global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
 
+        if self._use_cudnn_prefill:
+            self.cudnn_workspace = torch.empty(
+                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+                dtype=torch.int8,
+                device=runner.device,
+            )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
@@ -692,15 +727,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                          out=cu_seq_lens_cpu[:, 1:],
                          dtype=torch.int32)
 
+            chunked_context_metadata_cls = \
+                CudnnPrefillMetadata.ChunkedContextMetadata \
+                if self._use_cudnn_prefill else \
+                MLACommonPrefillMetadata.ChunkedContextMetadata
+
             chunked_context_metadata = \
-                MLACommonPrefillMetadata.ChunkedContextMetadata(
+                chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
 
+            if self._use_cudnn_prefill:
+                chunked_context_metadata.seq_lens = chunk_seq_lens
+
             assert max(chunked_context_metadata.max_seq_lens) <= \
                 self.chunked_prefill_workspace_size
@@ -711,6 +755,12 @@
                 chunked_context=chunked_context_metadata,
             )
 
+        if self._use_cudnn_prefill:
+            assert isinstance(prefill_metadata, CudnnPrefillMetadata)
+            prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
+                - prefill_query_start_loc[:-1]
+            prefill_metadata.cudnn_workspace = self.cudnn_workspace
+
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
@@ -794,6 +844,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_cudnn_prefill():
+            logger.debug_once("Using CUDNN prefill for MLA")
+            self._run_prefill_context_chunk = \
+                self._run_prefill_context_chunk_cudnn
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
+            self._pad_v = False
         else:  # Use FlashAttention
             logger.debug_once("Using FlashAttention prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
@@ -882,6 +938,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             return_lse=return_softmax_lse,
         )
 
+    def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
+                                      q, k, v, return_softmax_lse):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.query_seq_lens is not None
+        output, lse = cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.max_query_len,
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            causal=True,
+            return_lse=True,  # do not support False for now
+            is_cuda_graph_compatible=
+            True,  #Indicates actual_seq_lens are on GPU or CPU.
+        )
+        if return_softmax_lse:
+            return output, lse
+        return output
+
     def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
                                       chunk_idx: int, q, k, v):
         assert prefill.chunked_context is not None
@@ -908,6 +987,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             return_lse=True,
         )
 
+    def _run_prefill_context_chunk_cudnn(self,
+                                         prefill: MLACommonPrefillMetadata,
+                                         chunk_idx: int, q, k, v):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+        assert prefill.query_seq_lens is not None
+        return cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
+            view(-1, 1, 1, 1),
+            causal=False,
+            return_lse=True,
+            is_cuda_graph_compatible=
+            True,  #Indicates actual_seq_lens are on GPU or CPU.
+        )
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
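
For context, a minimal usage sketch of the flag this diff introduces, assuming FlashInfer is installed and the GPU reports SM100 (Blackwell) so that use_cudnn_prefill() returns True. Only VLLM_USE_CUDNN_PREFILL and the public vllm.LLM API are relied on; the model name, prompt, and sampling settings below are illustrative placeholders.

```python
# Illustrative sketch: enable the cuDNN prefill path for an MLA model.
# VLLM_USE_CUDNN_PREFILL comes from this diff (parsed via bool(int(...)) in
# vllm/envs.py); the model name and prompt are placeholders only.
import os

# Set the flag before constructing the engine so vllm.envs picks it up.
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # any MLA model; illustrative
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```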