[V1][CUDA] Full cudagraph support for FlashInfer (#21367)

fhl2000 authored on 2025-08-02 09:49:34 +08:00, committed by GitHub
Parent: 3654847db5
Commit: 23322431c8
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 376 additions and 47 deletions
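As a rough usage sketch (not part of the diff): with this change, full CUDA graphs can be enabled together with the FlashInfer attention backend in vLLM V1. The snippet below assumes the `LLM` entry point, the `CompilationConfig.full_cuda_graph` flag, and the `VLLM_ATTENTION_BACKEND` environment variable; treat it as illustrative rather than documented configuration.

    import os

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Assumed knobs: FlashInfer selected via the attention-backend env var,
    # full CUDA graphs requested through CompilationConfig.full_cuda_graph.
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        compilation_config=CompilationConfig(full_cuda_graph=True),
    )
    print(llm.generate("Hello, my name is")[0].outputs[0].text)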

View File

@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(

 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
+        else AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

View File

@@ -4,26 +4,28 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional, Union

 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache
+from flashinfer.decode import (_get_range_buf, get_seq_lens,
+                               trtllm_batch_decode_with_kv_cache)

 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import use_trtllm_decode_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    get_per_layer_parameters, infer_global_hyperparameters,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
+    get_kv_cache_layout, get_per_layer_parameters,
+    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
+    split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:
@@ -174,26 +176,66 @@ class FlashInferMetadata:

 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.device = device
+        self.vllm_config = vllm_config
+        self.cache_config = vllm_config.cache_config
+        self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+
+        self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+                                     self.kv_cache_spec.block_size)
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_reqs * max_num_pages_per_req
+        self.enable_cuda_graph = self.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[
+                int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(
+                max_num_reqs, self.compilation_config.max_capture_size)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

-        self.vllm_config = vllm_config
-        self.cache_config = vllm_config.cache_config
-        self.kv_cache_spec = kv_cache_spec
-        max_num_blocks_per_request = cdiv(
-            vllm_config.model_config.max_model_len,
-            self.kv_cache_spec.block_size)
-        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+        # Preparing persistent buffers (device-side)
+        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.paged_kv_indices = torch.zeros(
+            max_num_pages,  # max num pages possible
+            dtype=torch.int32,
+            device=self.device)
+        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                  dtype=torch.int32,
+                                                  device=self.device)
+        # host-side buffers
+        pin_memory = is_pin_memory_available()
+        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=pin_memory)
+        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
+                                                dtype=torch.int32,
+                                                device="cpu",
+                                                pin_memory=pin_memory)
+        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device="cpu",
+                                                      pin_memory=pin_memory)
+
+        self.block_table_arange = torch.arange(max_num_pages_per_req,
                                                dtype=torch.int32,
                                                device=self.device)
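A quick aside on why these persistent, pre-allocated buffers matter: a captured CUDA graph replays the exact kernels and memory addresses recorded at capture time, so per-step data must be copied into fixed buffers instead of being materialized as freshly allocated tensors. A minimal, self-contained illustration in plain PyTorch (unrelated to FlashInfer's wrappers; requires a CUDA device):

    import torch

    # Static (persistent) tensors: their addresses are baked into the graph.
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Work recorded into the graph reads/writes the static tensors.
        static_out.copy_(static_in * 2)

    # Replay reuses the captured addresses: update inputs in place first.
    static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
    g.replay()
    print(static_out.tolist())  # [0.0, 2.0, 4.0, ..., 14.0]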
@@ -217,8 +259,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper

-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self,
+                            batch_size: int,
+                            use_cudagraph: bool = False):
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(
+                batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (
                 self.vllm_config.model_config.get_num_attention_heads(
                     self.vllm_config.parallel_config))
@@ -226,11 +276,32 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.vllm_config.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:
+                                                                     batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # save the decode wrapper
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper

     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
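For reference, a stripped-down sketch of constructing a cudagraph-enabled FlashInfer decode wrapper bound to persistent buffers, mirroring the call in the hunk above. The workspace size, buffer sizes, batch size, and the "NHD" layout are illustrative assumptions, not values taken from this diff:

    import torch
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper

    batch_size, max_pages = 4, 1024
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    # Persistent device buffers whose addresses stay fixed across replays.
    indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    indices = torch.zeros(max_pages, dtype=torch.int32, device="cuda")
    last_page_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

    # One wrapper per captured batch size, pinned to slices of those buffers.
    wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        paged_kv_indptr_buffer=indptr,
        paged_kv_indices_buffer=indices,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=False)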
@@ -308,16 +379,44 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )

         if num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = num_prefills == 0
+            # Possible padding required for cudagraph replay
+            use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                             num_decodes <= self._decode_cudagraph_max_bs)
+            if use_cudagraph:
+                num_input_tokens = (
+                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                # Carefully fill the padding region with reasonable values
+                # on the CPU side.
+                # Make sure paged_kv_indptr_cpu is not decreasing
+                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
+                                         num_input_tokens].fill_(
+                                             attn_metadata.
+                                             paged_kv_indptr_cpu[-1])
+                # Fill the remaining paged_kv_last_page_len_cpu with 1.
+                # This is because flashinfer treats 0 as a full page
+                # instead of an empty one.
+                self.paged_kv_last_page_len_cpu[
+                    num_decodes:num_input_tokens].fill_(1)
+
+            else:
+                num_input_tokens = num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens, use_cudagraph)
             if not use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
                     self.cache_config.cache_dtype,
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
-                attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
+                # Use the persistent buffers with the padded length, instead
+                # of the chunked views (at the same base address) stored in
+                # attn_metadata, when using cudagraph.
+                fast_plan_decode(
+                    attn_metadata.decode_wrapper,
+                    self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
+                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
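A small worked example of the padding logic above (illustrative numbers): three real decode requests padded up to a capture size of four. The padded indptr entry repeats the last real value so the sequence stays non-decreasing, and the padded last_page_len entry becomes 1, since FlashInfer treats 0 as a full page rather than an empty one.

    import torch

    num_decodes, num_input_tokens = 3, 4
    indptr = torch.tensor([0, 5, 9, 12, 0], dtype=torch.int32)
    last_page_len = torch.tensor([7, 3, 11, 0], dtype=torch.int32)

    # Repeat the last real indptr value into the padding slots.
    indptr[1 + num_decodes:1 + num_input_tokens].fill_(indptr[num_decodes])
    # Padded requests pretend to have a single-token last page.
    last_page_len[num_decodes:num_input_tokens].fill_(1)

    print(indptr.tolist())         # [0, 5, 9, 12, 12]
    print(last_page_len.tolist())  # [7, 3, 11, 1]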
@@ -336,6 +435,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashInferMetadata:
+        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)
@@ -381,18 +481,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                    non_blocking=True)
         mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+        # write self.paged_kv_indices inplace
+        num_actual_pages = torch.sum(mask)
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        torch.masked_select(block_table_tensor[:, :max_num_blocks],
+                            mask,
+                            out=paged_kv_indices)

-        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
-                                          dtype=torch.int32,
-                                          device='cpu')
-        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
-            dim=0, dtype=torch.int32)
+        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
+        torch.cumsum(block_table_bounds_cpu,
+                     dim=0,
+                     dtype=torch.int32,
+                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])

         paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
-        paged_kv_last_page_len_cpu = torch.where(
-            paged_kv_last_page_len_cpu == 0, page_size,
-            paged_kv_last_page_len_cpu)
+        # write self.paged_kv_last_page_len_cpu inplace
+        torch.where(paged_kv_last_page_len_cpu == 0,
+                    torch.tensor(page_size),
+                    paged_kv_last_page_len_cpu,
+                    out=self.paged_kv_last_page_len_cpu[:num_reqs])

         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -402,9 +510,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
-            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
+            paged_kv_last_page_len_cpu=self.
+            paged_kv_last_page_len_cpu[:num_reqs],
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -431,6 +540,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

         return attn_metadata

+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with
+        FlashInfer.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seqs."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
@@ -638,3 +767,163 @@ class FlashInferImpl(AttentionImpl):
                 out=output[:num_decode_tokens],
             )
         return output_padded
+
+
+def fast_plan_decode(
+    self,  # decode wrapper
+    indptr_cpu: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len_cpu: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+    non_blocking: bool = True,
+) -> None:
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
+    cudagraph capture/replay; the non-cudagraph path falls back to the
+    original plan.
+
+    Modifications for cudagraph:
+    - only host-to-device copies of the indptr and last_page_len buffers.
+    - avoids the device-to-device copy of the indices buffer.
+
+    Parts of this code are inspired by the original plan in the FlashInfer
+    repo and by the fast_decode_plan implementation for FlashInfer in the
+    SGLang repo.
+    """
+    # Warm up with the original plan on the first call, and always run the
+    # original plan when running with dynamic shapes. For fixed shapes
+    # (cudagraph), this warm-up generates the _cached_module for the decode
+    # wrapper.
+    if not self.is_cuda_graph_enabled or \
+            getattr(self, "vllm_first_call", True):
+        self.plan(
+            indptr_cpu,
+            indices,
+            last_page_len_cpu,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_size,
+            pos_encoding_mode,
+            window_left,
+            logits_soft_cap,
+            q_data_type,
+            kv_data_type,
+            data_type,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            non_blocking,
+        )
+        self.vllm_first_call = False
+        return
+
+    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
+
+    batch_size = len(last_page_len_cpu)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+    q_data_type = getattr(torch, q_data_type) if isinstance(
+        q_data_type, str) else q_data_type
+    kv_data_type = getattr(torch, kv_data_type) if isinstance(
+        kv_data_type, str) else kv_data_type
+
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
+    if batch_size != self._fixed_batch_size:
+        raise ValueError(
+            "The batch size must be fixed in cudagraph mode: the runtime "
+            "batch size {} does not match the batch size {} set during "
+            "initialization".format(batch_size, self._fixed_batch_size))
+    if len(indices) > len(self._paged_kv_indices_buf):
+        raise ValueError(
+            "The size of indices should be less than or equal to the "
+            "allocated buffer")
+
+    # host-to-device copy for the indptr buffer
+    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
+    # host-to-device copy for the last_page_len buffer
+    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
+                                           non_blocking=True)
+
+    indptr_host = indptr_cpu
+    last_page_len_host = last_page_len_cpu
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
+                                        page_size)
+
+        try:
+            # Make sure we pass exactly 15 arguments for the tensor core
+            # version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                indptr_host,
+                kv_lens_arr_host,
+                batch_size,  # total_num_rows
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim,
+                head_dim,
+                False,  # causal
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in tensor core plan: {e}") from e
+    else:
+        try:
+            # Make sure we pass exactly 15 arguments for the standard version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                indptr_host,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                window_left,
+                logits_soft_cap,
+                head_dim,
+                head_dim,
+                torch.empty(0, dtype=q_data_type),
+                torch.empty(0, dtype=kv_data_type),
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in standard plan: {e}") from e
+
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
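To make the "only host-to-device copies" point above concrete, here is a tiny, self-contained sketch of the update pattern the replay path relies on: pinned CPU buffers allow the copies to be asynchronous, and the destination device buffers keep the addresses the captured graph was recorded with (the shapes and values here are illustrative):

    import torch

    pin = torch.zeros(16, dtype=torch.int32, pin_memory=True)   # host-side
    dev = torch.zeros(16, dtype=torch.int32, device="cuda")     # persistent

    pin[:4] = torch.tensor([0, 3, 7, 9], dtype=torch.int32)
    # Async H2D copy into a fixed slice of the persistent device buffer.
    dev[:4].copy_(pin[:4], non_blocking=True)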

View File

@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):

 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

View File

@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 # yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):

 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # decode only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

View File

@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:

 M = TypeVar("M")


+class AttentionCGSupport(enum.Enum):
+    """Constants for the cudagraph support of the attention backend.
+    Cascade attention is not considered here, as it currently never
+    supports cudagraphs."""
+
+    NEVER = 0
+    """No cudagraph support"""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure decode; mixed prefill-decode batches
+    need to run without cudagraph"""
+    ALWAYS = 2
+    """Cudagraph always supported"""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
View File

@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.device,
         )

-        if (self.full_cuda_graph
-                and not attn_metadata_builder_i.full_cudagraph_supported):
-            raise ValueError(
-                f"Full CUDAGraph not supported for "
-                f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                f"full_cuda_graph or use a different attention backend.")
+        if self.full_cuda_graph:
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.NEVER:
+                raise ValueError(f"Full CUDAGraph not supported for "
+                                 f"{attn_backend_i.__name__}. Turn off "
+                                 f"CompilationConfig.full_cuda_graph or use "
+                                 f"a different attention backend.")
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.PURE_DECODE_ONLY:
+                # Limit the max cudagraph size to the max number of
+                # sequences for pure-decode-only cudagraph backends,
+                # whose max_query_len is 1.
+                self.cudagraph_batch_sizes = [
+                    size for size in self.cudagraph_batch_sizes
+                    if size <= self.scheduler_config.max_num_seqs
+                ]
+
         return attn_backend_i, attn_metadata_builder_i

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
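A standalone sketch of the selection logic this hunk adds, with the enum duplicated locally so the snippet runs on its own (the function name and numbers are illustrative, not part of the diff):

    import enum

    class AttentionCGSupport(enum.Enum):
        NEVER = 0
        PURE_DECODE_ONLY = 1
        ALWAYS = 2

    def capture_sizes_for(support: AttentionCGSupport, sizes: list[int],
                          max_num_seqs: int) -> list[int]:
        if support == AttentionCGSupport.NEVER:
            raise ValueError("Full CUDAGraph not supported by this backend")
        if support == AttentionCGSupport.PURE_DECODE_ONLY:
            # Decode-only graphs have one query token per request, so the
            # captured batch size can never exceed max_num_seqs.
            return [s for s in sizes if s <= max_num_seqs]
        return sizes

    print(capture_sizes_for(AttentionCGSupport.PURE_DECODE_ONLY,
                            [1, 2, 4, 8, 512], max_num_seqs=256))
    # -> [1, 2, 4, 8]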

View File

@@ -321,11 +321,16 @@ class Worker(WorkerBase):
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # Activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access during full cudagraph replay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and\
+                not self.model_config.enforce_eager
+
             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model: