mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 09:15:17 +08:00
[Nvidia] Integrate SM100 cudnn prefill API to MLA prefill (#20411)
Signed-off-by: Elfie Guo <elfieg@nvidia.com> Co-authored-by: Elfie Guo <eflieg@nvidia.com>
This commit is contained in:
parent
10be209493
commit
30800b01c2
5
vllm/envs.py
Normal file → Executable file
5
vllm/envs.py
Normal file → Executable file
@ -139,6 +139,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
|
||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
|
||||
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||
VLLM_LOOPBACK_IP: str = ""
|
||||
|
||||
|
||||
@ -962,6 +963,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
|
||||
|
||||
# Controls whether or not to use cudnn prefill
|
||||
"VLLM_USE_CUDNN_PREFILL":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
|
||||
|
||||
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
|
||||
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
|
||||
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
|
||||
|
||||
113
vllm/v1/attention/backends/mla/common.py
Normal file → Executable file
113
vllm/v1/attention/backends/mla/common.py
Normal file → Executable file
@ -194,6 +194,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
@ -225,6 +226,8 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
|
||||
from flashinfer.prefill import ( # noqa: F401
|
||||
cudnn_batch_prefill_with_kv_cache)
|
||||
flashinfer_available = True
|
||||
except ImportError:
|
||||
flashinfer_available = False
|
||||
@ -236,6 +239,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CUDNN_WORKSPACE_SIZE = 12800
|
||||
|
||||
|
||||
class MLACommonBackend(AttentionBackend):
|
||||
|
||||
@ -294,6 +299,7 @@ class MLACommonPrefillMetadata:
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
seq_lens: torch.Tensor
|
||||
workspace: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
|
||||
default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CudnnPrefillMetadata(MLACommonPrefillMetadata):
|
||||
|
||||
class ChunkedContextMetadata(
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata):
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
query_seq_lens: Optional[torch.Tensor] = None
|
||||
cudnn_workspace: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):
|
||||
|
||||
decode: Optional[D] = None
|
||||
prefill: Optional[Union[MLACommonPrefillMetadata,
|
||||
FlashInferPrefillMetadata]] = None
|
||||
FlashInferPrefillMetadata,
|
||||
CudnnPrefillMetadata]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_dim is not None:
|
||||
@ -362,13 +380,19 @@ M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
if flashinfer_available:
|
||||
if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
|
||||
# For blackwell default to flashinfer prefill if its available since
|
||||
# its faster than FA2.
|
||||
return current_platform.has_device_capability(100)
|
||||
return False
|
||||
|
||||
|
||||
def use_cudnn_prefill() -> bool:
|
||||
if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
|
||||
return current_platform.has_device_capability(100)
|
||||
return False
|
||||
|
||||
|
||||
# Currently 394MB, this can be tuned based on GEMM sizes used.
|
||||
# Choosen to be the same as sglang:
|
||||
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
|
||||
@ -427,11 +451,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
)
|
||||
|
||||
self.block_table = block_table
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
self.prefill_metadata_cls = FlashInferPrefillMetadata \
|
||||
if self._use_fi_prefill else MLACommonPrefillMetadata
|
||||
self.prefill_metadata_cls = (
|
||||
FlashInferPrefillMetadata
|
||||
if self._use_fi_prefill else CudnnPrefillMetadata
|
||||
if self._use_cudnn_prefill else MLACommonPrefillMetadata)
|
||||
|
||||
if self._use_fi_prefill:
|
||||
self._workspace_buffer = torch.empty(
|
||||
@ -447,6 +475,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
self._global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
self.cudnn_workspace = torch.empty(
|
||||
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
|
||||
dtype=torch.int8,
|
||||
device=runner.device,
|
||||
)
|
||||
|
||||
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
|
||||
qo_indptr = prefill.query_start_loc
|
||||
|
||||
@ -692,15 +727,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
chunked_context_metadata_cls = \
|
||||
CudnnPrefillMetadata.ChunkedContextMetadata \
|
||||
if self._use_cudnn_prefill else \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata
|
||||
|
||||
chunked_context_metadata = \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
chunked_context_metadata.seq_lens = chunk_seq_lens
|
||||
|
||||
assert max(chunked_context_metadata.max_seq_lens) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
|
||||
@ -711,6 +755,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
chunked_context=chunked_context_metadata,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
assert isinstance(prefill_metadata, CudnnPrefillMetadata)
|
||||
prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
|
||||
- prefill_query_start_loc[:-1]
|
||||
prefill_metadata.cudnn_workspace = self.cudnn_workspace
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
@ -794,6 +844,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
||||
self._pad_v = False
|
||||
elif use_cudnn_prefill():
|
||||
logger.debug_once("Using CUDNN prefill for MLA")
|
||||
self._run_prefill_context_chunk = \
|
||||
self._run_prefill_context_chunk_cudnn
|
||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
|
||||
self._pad_v = False
|
||||
else: # Use FlashAttention
|
||||
logger.debug_once("Using FlashAttention prefill for MLA")
|
||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
|
||||
@ -882,6 +938,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
return_lse=return_softmax_lse,
|
||||
)
|
||||
|
||||
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
||||
q, k, v, return_softmax_lse):
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
assert prefill.query_seq_lens is not None
|
||||
output, lse = cudnn_batch_prefill_with_kv_cache(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
v_cache=v,
|
||||
scale=self.scale,
|
||||
workspace_buffer=prefill.cudnn_workspace,
|
||||
max_token_per_sequence=prefill.max_query_len,
|
||||
max_sequence_kv=prefill.max_query_len,
|
||||
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||
actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||
causal=True,
|
||||
return_lse=True, # do not support False for now
|
||||
is_cuda_graph_compatible=
|
||||
True, #Indicates actual_seq_lens are on GPU or CPU.
|
||||
)
|
||||
if return_softmax_lse:
|
||||
return output, lse
|
||||
return output
|
||||
|
||||
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int, q, k, v):
|
||||
assert prefill.chunked_context is not None
|
||||
@ -908,6 +987,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
def _run_prefill_context_chunk_cudnn(self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int, q, k, v):
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
assert prefill.chunked_context is not None
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
assert prefill.query_seq_lens is not None
|
||||
return cudnn_batch_prefill_with_kv_cache(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
v_cache=v,
|
||||
scale=self.scale,
|
||||
workspace_buffer=prefill.cudnn_workspace,
|
||||
max_token_per_sequence=prefill.max_query_len,
|
||||
max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||
actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
|
||||
view(-1, 1, 1, 1),
|
||||
causal=False,
|
||||
return_lse=True,
|
||||
is_cuda_graph_compatible=
|
||||
True, #Indicates actual_seq_lens are on GPU or CPU.
|
||||
)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user