[MISC] cudagraph_capture_sizes related improvements (#26016)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
fhl2000 2025-10-24 20:11:05 +08:00 committed by GitHub
parent 435be10db9
commit 284cc92275
14 changed files with 303 additions and 110 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
import pytest
@@ -8,6 +9,8 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@@ -233,3 +236,73 @@ def test_resolve_operator_overload():
assert len(resolved) == 2 # Only 2 valid ops
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_capture_sizes",
"max_cudagraph_capture_size",
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, True, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
# the list should contain at least 1 element when using cudagraph
([], None, 1, False, 2048, True, RuntimeError),
# the max capture size should be >= 1 when using cudagraph
(None, 0, 1, False, 2048, True, RuntimeError),
],
)
def test_cudagraph_sizes_post_init(
cudagraph_capture_sizes,
max_cudagraph_capture_size,
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, type) and issubclass(expected_max_size, Exception):
ctx = pytest.raises(expected_max_size)
cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
pass_config={
"enable_sequence_parallelism": enable_sequence_parallelism,
"enable_fusion": True,
"enable_noop": True,
},
cudagraph_mode=cudagraph_mode,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)

View File

@@ -154,6 +154,8 @@ class CompilationConfig:
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`max_cudagraph_capture_size`]
[vllm.config.CompilationConfig.max_cudagraph_capture_size]
- [`cudagraph_num_of_warmups`]
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
@@ -327,18 +329,16 @@ class CompilationConfig:
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
"""Whether to use cudagraph inside compilation:
- False: cudagraph inside compilation is not used.\n
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
In the vLLM V1 Engine, this flag only applies for
CompilationMode.VLLM_COMPILE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
instead.
minor release, i.e. v0.11.0 or v1.0.0. Please use
cudagraph_mode=FULL_AND_PIECEWISE instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
@@ -398,8 +398,22 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
max_capture_size: int = field(default=None, init=False) # type: ignore
"""not configurable, computed after init"""
max_cudagraph_capture_size: int | None = field(default=None)
"""The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest
size in that list (or checked for consistency if specified). If
cudagraph_capture_sizes is not specified, the list of sizes is generated
automatically following the pattern:
[1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_cudagraph_capture_size + 1, 16))
If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
512) by default. This voids OOM in tight memory scenarios with small
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
@@ -408,7 +422,7 @@
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_capture_size],
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops
@@ -672,25 +686,12 @@ class CompilationConfig:
return VllmBackend(vllm_config)
def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes."""
if self.cudagraph_capture_sizes is None:
self.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
(
"cudagraph sizes specified by model runner"
" %s is overridden by config %s"
),
cudagraph_capture_sizes,
dedup_sizes,
)
self.cudagraph_capture_sizes = dedup_sizes
def post_init_cudagraph_sizes(self) -> None:
"""To complete the initialization after cudagraph related
configs are set. This includes:
- initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
"""
computed_compile_sizes = []
if self.compile_sizes is not None:
@@ -708,23 +709,24 @@ class CompilationConfig:
computed_compile_sizes.append(x)
self.compile_sizes = computed_compile_sizes # type: ignore
# sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = (
self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
)
# make sure the sizes are in ascending order
self.cudagraph_capture_sizes.sort()
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
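A standalone sketch of the padding table this loop builds, assuming a capture list of [1, 2, 4]:

sizes = [1, 2, 4]
padded = [0] * (sizes[-1] + 1)
for end, start in zip(sizes + [sizes[-1] + 1], [0] + sizes):
    for bs in range(start, end):
        padded[bs] = start if bs == start else end
# padded == [0, 1, 2, 4, 4]; e.g. a batch of 3 is padded up to the captured size 4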
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when mode is

View File

@@ -71,14 +71,6 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
cuda_graph_sizes: list[int] = field(default_factory=list)
"""Cuda graph capture sizes
1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
2. if one value is provided, then the capture list would follow the
pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
3. more than one value (e.g. 1 2 128) is provided, then the capture list
will follow the provided list."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
@@ -235,13 +227,6 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
# NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
# This avoids OOM in tight memory scenarios with small max_num_seqs,
# and prevents capture of many large graphs (>512) that would greatly
# increase startup time with limited performance benefit.
if not self.cuda_graph_sizes:
self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
if self.async_scheduling:
self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"

View File

@@ -197,10 +197,10 @@ class VllmConfig:
return hash_str
def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_capture_size,
# if batch_size > self.compilation_config.max_cudagraph_capture_size,
# it should raise an IndexError.
# the caller should make sure the batch_size is within the range,
# i.e., batch_size <= self.compilation_config.max_capture_size
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]
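A usage sketch with hypothetical values, assuming the capture sizes resolved to [1, 2, 4, 8]:

padded = vllm_config.pad_for_cudagraph(5)  # -> 8, the next captured size
# pad_for_cudagraph(9) would raise IndexError, since 9 > max_cudagraph_capture_size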
@staticmethod
@@ -396,6 +396,9 @@ class VllmConfig:
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# override related settings when enforce eager
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
elif envs.VLLM_USE_V1:
self.compilation_config.cudagraph_num_of_warmups = 1
@@ -654,11 +657,13 @@ class VllmConfig:
```python
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to max_graph_size
cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
will be the final sizes to capture cudagraph (in descending order).
will be the final sizes to capture cudagraph (in ascending order).
These sizes are used to capture and reuse CUDA graphs for
performance-critical paths (e.g., decoding). Capturing enables
@@ -685,35 +690,111 @@ class VllmConfig:
not be used.
"""
# calculate the default `batch_size_capture_list`
batch_size_capture_list = []
if self.model_config is not None and not self.model_config.enforce_eager:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
if len(cuda_graph_sizes) == 1:
max_graph_size = cuda_graph_sizes[0]
assert max_graph_size >= 1, (
"Maximum cudagraph size should be greater than or equal to 1."
if (
self.model_config is not None
and not self.model_config.enforce_eager
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# determine the initial max_cudagraph_capture_size
max_cudagraph_capture_size = (
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
)
batch_size_capture_list = [
i for i in [1, 2, 4] if i <= max_graph_size
] + list(range(8, max_graph_size + 1, 8))
elif len(cuda_graph_sizes) > 1:
batch_size_capture_list = sorted(cuda_graph_sizes)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
assert max_cudagraph_capture_size >= 1, (
"Maximum cudagraph size should be greater than or equal to 1 "
"when using cuda graph."
)
# determine the cudagraph_capture_sizes
if self.compilation_config.cudagraph_capture_sizes is not None:
assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
"cudagraph_capture_sizes should contain at least one element "
"when using cuda graph."
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = dedup_sizes
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256 (not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8)
)
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16)
)
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
):
batch_size_capture_list = self.update_sizes_for_sequence_parallelism(
batch_size_capture_list
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list if size <= max_num_tokens
]
self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list)
# a user-specified compilation_config.max_cudagraph_capture_size gets
# truncated to valid_max_size when they are inconsistent.
valid_max_size = (
cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
)
if (
self.compilation_config.max_cudagraph_capture_size is not None
and self.compilation_config.max_cudagraph_capture_size != valid_max_size
):
# raise an error only when both flags are user-specified
# and they are inconsistent with each other
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={self.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
)
# always set the final max_cudagraph_capture_size
self.compilation_config.max_cudagraph_capture_size = valid_max_size
if self.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes
) < len(self.compilation_config.cudagraph_capture_sizes):
# If users have specified capture sizes, we only need to
# compare the lengths before and after modification since the modified
# list is always a subset of the original list.
logger.warning(
(
"cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"
),
self.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
# always write back the final sizes
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
# no cudagraph in use
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
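A hedged walk-through of the resolution above, mirroring one of the new test cases (user sizes [1, 2, 4, 15], tp_size=2 with sequence parallelism, no explicit max); the divisible-by-tp rule is an assumption about what update_sizes_for_sequence_parallelism does:

user_sizes = [1, 2, 4, 15]
tp_size = 2
sizes = sorted(set(user_sizes))                 # de-duplicate and sort ascending
sizes = [s for s in sizes if s % tp_size == 0]  # assumed SP filter rule -> [2, 4]
final_max = sizes[-1] if sizes else 0           # -> 4, the expected max in the test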
def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config

View File

@@ -364,7 +364,13 @@ class EngineArgs:
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: int | None = ModelConfig.seed
max_model_len: int | None = ModelConfig.max_model_len
cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
cudagraph_capture_sizes: list[int] | None = (
CompilationConfig.cudagraph_capture_sizes
)
max_cudagraph_capture_size: int | None = get_field(
CompilationConfig, "max_cudagraph_capture_size"
)
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
@@ -1007,9 +1013,6 @@ class EngineArgs:
"--max-long-partial-prefills",
**scheduler_kwargs["max_long_partial_prefills"],
)
scheduler_group.add_argument(
"--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
)
scheduler_group.add_argument(
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"],
@@ -1039,6 +1042,29 @@ class EngineArgs:
"--async-scheduling", **scheduler_kwargs["async_scheduling"]
)
# Compilation arguments
compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
)
compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
)
compilation_group.add_argument(
"--cuda-graph-sizes",
**compilation_kwargs["cudagraph_capture_sizes"],
deprecated=True,
)
compilation_group.add_argument(
"--max-cudagraph-capture-size",
**compilation_kwargs["max_cudagraph_capture_size"],
)
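A hypothetical programmatic equivalent of the two new CLI flags, using the EngineArgs fields added earlier in this diff:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    cudagraph_capture_sizes=[1, 2, 4, 8],
    max_cudagraph_capture_size=8,
)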
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
@@ -1548,7 +1574,6 @@ class EngineArgs:
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
cuda_graph_sizes=self.cuda_graph_sizes,
num_lookahead_slots=num_lookahead_slots,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_chunked_mm_input=self.disable_chunked_mm_input,
@@ -1616,6 +1641,38 @@ class EngineArgs:
collect_detailed_traces=self.collect_detailed_traces,
)
# Compilation config overrides
if self.cuda_graph_sizes is not None:
logger.warning(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
if self.cudagraph_capture_sizes is not None:
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cudagraph_capture_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
self.compilation_config.cudagraph_capture_sizes = (
self.cudagraph_capture_sizes
)
if self.max_cudagraph_capture_size is not None:
if self.compilation_config.max_cudagraph_capture_size is not None:
raise ValueError(
"max_cudagraph_capture_size and compilation_config."
"max_cudagraph_capture_size are mutually exclusive"
)
self.compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,

View File

@@ -185,7 +185,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.moe = moe
self.mxfp4_backend = get_mxfp4_backend()
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_capture_size
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
assert self.mxfp4_backend != Mxfp4Backend.NONE, (

View File

@@ -259,21 +259,19 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
# Increase the max capture size from 512 to 992 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs
# from 67 to 81.
scheduler_config = vllm_config.scheduler_config
if len(scheduler_config.cuda_graph_sizes) == 1:
max_capture_size = scheduler_config.cuda_graph_sizes[0]
compilation_config = vllm_config.compilation_config
# Only override when the user has not set either of
# cudagraph_capture_sizes or max_cudagraph_capture_size.
if (
compilation_config.cudagraph_capture_sizes is None
and compilation_config.max_cudagraph_capture_size is None
):
# FIXME(woosuk): When using full cuda graph with FA3, the max
# supported size is 992.
if max_capture_size < 992:
cuda_graph_sizes = [1, 2, 4]
# Step size 8 for small batch sizes
cuda_graph_sizes += [i for i in range(8, 256, 8)]
# Step size 16 for larger batch sizes
cuda_graph_sizes += [i for i in range(256, 993, 16)]
scheduler_config.cuda_graph_sizes = cuda_graph_sizes
logger.info(
"Overriding max cuda graph capture size to %d for performance.", 992
)
compilation_config.max_cudagraph_capture_size = 992
logger.info(
"Overriding max cuda graph capture size to %d for performance.", 992
)
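For reference, a sketch of the list the default generator would produce for the 992 ceiling set here (plain arithmetic, following the pattern documented in CompilationConfig):

sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, 993, 16))
# 3 + 31 + 47 = 81 sizes, matching the "from 67 to 81" note above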
class MambaModelConfig(VerifyAndUpdateConfig):

View File

@@ -236,7 +236,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_capture_size
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.aot_schedule:
if self.max_cudagraph_size > 992:

View File

@@ -324,7 +324,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
] = {}
self._decode_cudagraph_max_bs = min(
(1 + num_spec_tokens) * max_num_reqs,
self.compilation_config.max_capture_size,
self.compilation_config.max_cudagraph_capture_size,
)
self.num_qo_heads = self.model_config.get_num_attention_heads(

View File

@@ -87,7 +87,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_capture_size,
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(

View File

@@ -36,7 +36,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_capture_size,
self.compilation_config.max_cudagraph_capture_size,
)
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),

View File

@@ -89,7 +89,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_capture_size
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule:
if self.max_cudagraph_size > 992:

View File

@@ -104,7 +104,7 @@ class EagleProposer:
)
self.cudagraph_batch_sizes = (
list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
(sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
if self.use_cuda_graph
else []
)

View File

@@ -379,16 +379,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.cuda.Event()
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
if (
self.compilation_config.cudagraph_capture_sizes
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes)
self.cudagraph_batch_sizes = sorted(
self.compilation_config.cudagraph_capture_sizes
)
# Cache the device properties.
@@ -3791,7 +3788,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
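A minimal sketch of the ordering convention after this change: the config keeps capture sizes ascending, and capture-time code reverses them so the largest batch is captured first.

sizes = sorted([8, 1, 4, 2])           # stored in the config (ascending): [1, 2, 4, 8]
capture_order = list(reversed(sizes))  # capture order (largest first): [8, 4, 2, 1]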