[V1][CUDA] Full cudagraph support for FlashInfer (#21367)

Authored by fhl2000 on 2025-08-02 09:49:34 +08:00, committed by GitHub
parent 3654847db5
commit 23322431c8
8 changed files with 376 additions and 47 deletions
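For background on what the changes below rely on: a full CUDA graph records a fixed sequence of kernel launches against fixed tensor addresses and shapes, and a replay only re-reads the current contents of those tensors. A minimal, self-contained PyTorch sketch of that capture/replay pattern (independent of vLLM and FlashInfer):

import torch

# Persistent ("static") buffers: the captured graph bakes in these exact
# tensor addresses, so replays must reuse them and only refresh contents.
static_input = torch.zeros(8, 128, device="cuda")
weight = torch.randn(128, 128, device="cuda")
static_output = torch.empty(8, 128, device="cuda")

# Warm up on a side stream so lazy kernel/allocator init is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    torch.matmul(static_input, weight, out=static_output)
torch.cuda.current_stream().wait_stream(s)

# Capture: record the kernel launches once.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    torch.matmul(static_input, weight, out=static_output)

# Replay: overwrite the static input, then relaunch the whole recorded
# sequence with a single call; static_output is refreshed in place.
static_input.copy_(torch.randn(8, 128, device="cuda"))
graph.replay()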


@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):


@@ -4,26 +4,28 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional, Union
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import trtllm_batch_decode_with_kv_cache
from flashinfer.decode import (_get_range_buf, get_seq_lens,
trtllm_batch_decode_with_kv_cache)
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
@@ -174,26 +176,66 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._decode_wrapper = None # Wrapper for decode (general shape)
self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)
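# NOTE: one wrapper per captured batch size is needed because, in
# cudagraph mode, BatchDecodeWithPagedKVCacheWrapper plans against
# fixed-size indptr/last_page_len buffers (see the fixed-batch-size
# check in fast_plan_decode below), so a graph captured at batch size
# N can only be replayed at exactly that size.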
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
max_num_blocks_per_request = cdiv(
vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
self.block_table_arange = torch.arange(max_num_blocks_per_request,
# Preparing persistent buffers (device-side)
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)
# Host-side (pinned) staging buffers
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.block_table_arange = torch.arange(max_num_pages_per_req,
dtype=torch.int32,
device=self.device)
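The device-side buffers above give any captured graph stable addresses to read from, while the pinned host-side mirrors let per-step metadata be staged on the CPU and copied over asynchronously. A minimal sketch of that staging pattern (illustrative names and sizes, assuming a CUDA device is present):

import torch

# Host-side staging buffer; pinned memory lets the H2D copy be queued
# asynchronously instead of blocking the CPU.
indptr_cpu = torch.zeros(256 + 1, dtype=torch.int32, pin_memory=True)
# Device-side persistent buffer; a captured decode graph (and a planned
# wrapper) keep referring to this fixed address.
indptr_gpu = torch.zeros(256 + 1, dtype=torch.int32, device="cuda")

def stage_indptr(new_indptr_cpu: torch.Tensor) -> None:
    n = new_indptr_cpu.numel()
    indptr_cpu[:n].copy_(new_indptr_cpu)                       # CPU-side write
    indptr_gpu[:n].copy_(indptr_cpu[:n], non_blocking=True)    # async H2D copy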
@@ -217,8 +259,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper
if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
@@ -226,11 +276,32 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
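# Heuristic: the tensor-core decode kernel tends to pay off when the
# GQA group size (num_qo_heads / num_kv_heads) is large; the env var
# above can force it on unconditionally.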
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper
return decode_wrapper
def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
@@ -308,16 +379,44 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
pure_decode = num_prefills == 0
# Possible padding required for cudagraph replay.
use_cudagraph = (self.enable_cuda_graph and pure_decode and
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
# Carefully fill the padding region with reasonable values
# on the CPU side.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
num_input_tokens].fill_(
attn_metadata.
paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because FlashInfer treats 0 as a full page
# instead of an empty one.
self.paged_kv_last_page_len_cpu[
num_decodes:num_input_tokens].fill_(1)
else:
num_input_tokens = num_decodes
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
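# NOTE: the decode wrapper is fetched (or created) for the padded
# batch size, so a graph captured for num_input_tokens sequences can
# be replayed for any actual decode batch no larger than that size;
# the padded tail of the buffers was filled with benign values above.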
if not use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
# Use the persistent buffers sliced to the padded length,
# instead of the shorter views (over the same storage)
# held in attn_metadata, when using cudagraph.
fast_plan_decode(
attn_metadata.decode_wrapper,
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
self.paged_kv_last_page_len_cpu[:num_input_tokens],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -336,6 +435,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
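The decode path above pads a pure-decode batch up to the nearest captured graph size before planning. A minimal sketch of that round-up step (an illustrative helper, not vLLM's pad_for_cudagraph, though it is assumed to behave similarly):

def pad_to_capture_size(batch_size: int, capture_sizes: list[int]) -> int:
    """Round a decode batch up to the smallest captured cudagraph size.

    Illustrative only; assumes capture_sizes is sorted ascending and
    batch_size does not exceed the largest captured size.
    """
    for size in capture_sizes:
        if size >= batch_size:
            return size
    raise ValueError(f"batch size {batch_size} exceeds the largest capture size")

# Example: with capture sizes [1, 2, 4, 8, 16], a batch of 5 decodes
# replays the graph captured for batch size 8; the 3 padded slots repeat
# the last indptr value and use last_page_len = 1 so FlashInfer sees a
# valid (non-decreasing, non-empty) page layout.
pad_to_capture_size(5, [1, 2, 4, 8, 16])  # -> 8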
@@ -381,18 +481,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
# write self.paged_kv_indices inplace
num_actual_pages = torch.sum(mask)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks],
mask,
out=paged_kv_indices)
paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
dtype=torch.int32,
device='cpu')
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
dim=0, dtype=torch.int32)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
torch.cumsum(block_table_bounds_cpu,
dim=0,
dtype=torch.int32,
out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
paged_kv_last_page_len_cpu = torch.where(
paged_kv_last_page_len_cpu == 0, page_size,
paged_kv_last_page_len_cpu)
# write self.paged_kv_last_page_len_cpu inplace
torch.where(paged_kv_last_page_len_cpu == 0,
torch.tensor(page_size),
paged_kv_last_page_len_cpu,
out=self.paged_kv_last_page_len_cpu[:num_reqs])
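# NOTE: writing with out= into the preallocated persistent buffers,
# instead of allocating new tensors every step, keeps the device
# indices buffer at the address the cudagraph decode wrappers were
# built with and avoids reallocating the host-side staging tensors.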
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -402,9 +510,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
paged_kv_last_page_len_cpu=self.
paged_kv_last_page_len_cpu[:num_reqs],
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -431,6 +540,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
@@ -638,3 +767,163 @@ class FlashInferImpl(AttentionImpl):
out=output[:num_decode_tokens],
)
return output_padded
def fast_plan_decode(
self, # decode wrapper
indptr_cpu: torch.Tensor,
indices: torch.Tensor,
last_page_len_cpu: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
q_data_type: Optional[Union[str, torch.dtype]] = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
non_blocking: bool = True,
) -> None:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
cudagraph capture/replay, while the no cudagraph version turns back
to the original plan.
using original plan after passing host-side buffers:
- only host-to-device copy of indptr and last_page_len buffers
Modifications for cudagraph:
- only host-to-device copy of indptr and last_page_len buffers.
- avoid device-to-device copy of indices buffer.
Part of the code get inspiration from the original plan from FlashInfer repo
and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
"""
# Warm up with the original plan on the first call, and always use the
# original plan when running with dynamic shapes. For fixed shapes
# (cudagraph), this warm-up generates the _cached_module for the decode
# wrapper.
if not self.is_cuda_graph_enabled or \
getattr(self, "vllm_first_call", True):
self.plan(
indptr_cpu,
indices,
last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode,
window_left,
logits_soft_cap,
q_data_type,
kv_data_type,
data_type,
sm_scale,
rope_scale,
rope_theta,
non_blocking,
)
self.vllm_first_call = False
return
assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
batch_size = len(last_page_len_cpu)
if logits_soft_cap is None:
logits_soft_cap = 0.0
# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"
if kv_data_type is None:
kv_data_type = q_data_type
q_data_type = getattr(torch, q_data_type) if isinstance(
q_data_type, str) else q_data_type
kv_data_type = getattr(torch, kv_data_type) if isinstance(
kv_data_type, str) else kv_data_type
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size must be fixed in cudagraph mode; the runtime "
"batch size {} does not match the batch size {} set during "
"initialization".format(batch_size, self._fixed_batch_size))
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the "
"allocated buffer")
# host-to-device copy for the indptr buffer
self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
# host-to-device copy for the last_page_len buffer
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
non_blocking=True)
indptr_host = indptr_cpu
last_page_len_host = last_page_len_cpu
if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
page_size)
try:
# Make sure we pass exactly 15 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
else:
try:
# Make sure we pass exactly 15 arguments for standard version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
torch.empty(0, dtype=q_data_type),
torch.empty(0, dtype=kv_data_type),
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}") from e
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
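fast_plan_decode above follows a warm-up-then-fast-path pattern: the first call runs the original plan (which also builds and caches FlashInfer's planning module), and every later fixed-shape call only refreshes the host and device buffers. A stripped-down sketch of the same pattern, with hypothetical names rather than the FlashInfer API:

import torch

class CapturablePlanner:
    """Illustrative only: warm up once, then take a cheap fixed-shape path."""

    def __init__(self, fixed_batch_size: int):
        self._fixed_batch_size = fixed_batch_size
        self._warmed_up = False
        # Persistent device buffer a captured graph would read from.
        self._indptr_gpu = torch.zeros(fixed_batch_size + 1,
                                       dtype=torch.int32, device="cuda")

    def plan(self, indptr_cpu: torch.Tensor) -> None:
        if not self._warmed_up:
            # Slow path: in the real code this is the original plan(),
            # which also builds the cached planning module.
            self._indptr_gpu.copy_(indptr_cpu)
            self._warmed_up = True
            return
        # Fast path (fixed shape, e.g. under cudagraph replay): only the
        # contents of the persistent buffers need refreshing.
        assert indptr_cpu.numel() == self._fixed_batch_size + 1
        self._indptr_gpu.copy_(indptr_cpu, non_blocking=True)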


@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):


@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):


@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")
class AttentionCGSupport(enum.Enum):
""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""
NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
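For reference, a minimal sketch of how a runner might act on the three support levels declared above (a hypothetical helper; the actual handling is in the GPUModelRunner hunk further below):

def resolve_capture_sizes(support: AttentionCGSupport,
                          capture_sizes: list[int],
                          max_num_seqs: int) -> list[int]:
    """Illustrative only: trim cudagraph capture sizes to what a backend allows."""
    if support == AttentionCGSupport.NEVER:
        raise ValueError("Backend cannot run under full CUDA graphs.")
    if support == AttentionCGSupport.PURE_DECODE_ONLY:
        # Decode-only graphs have max_query_len == 1, so a captured batch
        # can never contain more sequences than max_num_seqs.
        return [size for size in capture_sizes if size <= max_num_seqs]
    return capture_sizes  # AttentionCGSupport.ALWAYS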


@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.device,
)
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph capture size to the max number of
# sequences for pure-decode-only cudagraph backends, whose
# max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_backend_i, attn_metadata_builder_i
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:


@@ -321,11 +321,16 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
# Activate building attn_metadata for this dummy run to avoid
# potential illegal memory access during full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(
num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True,
)
if self.model_runner.is_pooling_model: