diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 4f782ef92c55..4145e84c2ee0 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+from contextlib import nullcontext

 import pytest

@@ -8,6 +9,8 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.config.compilation import CompilationMode
+from vllm.engine.arg_utils import EngineArgs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@@ -233,3 +236,73 @@ def test_resolve_operator_overload():
     assert len(resolved) == 2  # Only 2 valid ops
     assert resolved[0] is torch.ops.aten.mm.default
     assert resolved[1] is torch.ops.aten.addmm.default
+
+
+@pytest.mark.skipif(
+    not current_platform.support_static_graph_mode(),
+    reason="Skip if not cudagraph mode supported",
+)
+@pytest.mark.parametrize(
+    (
+        "cudagraph_capture_sizes",
+        "max_cudagraph_capture_size",
+        "tp_size",
+        "enable_sequence_parallelism",
+        "max_num_batched_tokens",
+        "use_cudagraph",
+        "expected_max_size",
+    ),
+    [
+        (None, None, 1, False, 2048, True, 512),
+        ([1, 2, 4], 4, 1, False, 2048, True, 4),
+        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
+        ([1, 256], None, 1, False, 2048, True, 256),
+        ([], None, 1, False, 2048, False, 0),
+        (None, 0, 1, False, 2048, False, 0),
+        # truncated to nearest multiple of 8 or 16
+        (None, 257, 1, False, 2048, True, 256),
+        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
+        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # filtered out 15 due to SP
+        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by the max_tokens
+        # the list should contain at least 1 element when use cudagraph
+        ([], None, 1, False, 2048, True, RuntimeError),
+        # the max capturing size should be >= 1 when use cudagraph
+        (None, 0, 1, False, 2048, True, RuntimeError),
+    ],
+)
+def test_cudagraph_sizes_post_init(
+    cudagraph_capture_sizes,
+    max_cudagraph_capture_size,
+    tp_size,
+    enable_sequence_parallelism,
+    max_num_batched_tokens,
+    use_cudagraph,
+    expected_max_size,
+):
+    ctx = nullcontext()
+    if isinstance(expected_max_size, type) and issubclass(
+        expected_max_size, Exception
+    ):
+        ctx = pytest.raises(expected_max_size)
+
+    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
+    with ctx:
+        compilation_config = CompilationConfig(
+            cudagraph_capture_sizes=cudagraph_capture_sizes,
+            max_cudagraph_capture_size=max_cudagraph_capture_size,
+            pass_config={
+                "enable_sequence_parallelism": enable_sequence_parallelism,
+                "enable_fusion": True,
+                "enable_noop": True,
+            },
+            cudagraph_mode=cudagraph_mode,
+        )
+        engine_args = EngineArgs(
+            model="facebook/opt-125m",
+            tensor_parallel_size=tp_size,
+            max_num_batched_tokens=max_num_batched_tokens,
+            compilation_config=compilation_config,
+        )
+        vllm_config = engine_args.create_engine_config()
+
+        assert (
+            vllm_config.compilation_config.max_cudagraph_capture_size
+            == expected_max_size
+        )
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 61e73414335a..c24a94091be4 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -154,6 +154,8 @@ class CompilationConfig:
         - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
         - [`cudagraph_capture_sizes`]
        [vllm.config.CompilationConfig.cudagraph_capture_sizes]
+        - [`max_cudagraph_capture_size`]
+        [vllm.config.CompilationConfig.max_cudagraph_capture_size]
         - [`cudagraph_num_of_warmups`]
         [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
         - [`cudagraph_copy_inputs`]
@@ -327,18 +329,16 @@ class CompilationConfig:
     use_cudagraph: bool = True
-    """Whether to use cudagraph inside compilation.
-    - False: cudagraph inside compilation is not used.
+    """Whether to use cudagraph inside compilation:
+
+    - False: cudagraph inside compilation is not used.\n
     - True: cudagraph inside compilation is used. It requires
        that all input buffers have fixed addresses, and all
        splitting ops write their outputs to input buffers.
-    In the vLLM V1 Engine, this flag only applies for
-    CompilationMode.VLLM_COMPILE (aka -O3).
-    Note that this is orthogonal to the cudagraph capture logic
-    outside of compilation.
+
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
-    instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use
+    cudagraph_mode=FULL_AND_PIECEWISE instead.
     """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
@@ -398,8 +398,22 @@ class CompilationConfig:
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""

-    max_capture_size: int = field(default=None, init=False)  # type: ignore
-    """not configurable, computed after init"""
+    max_cudagraph_capture_size: int | None = field(default=None)
+    """The maximum cudagraph capture size.
+
+    If cudagraph_capture_sizes is specified, this will be set to the largest
+    size in that list (or checked for consistency if specified). If
+    cudagraph_capture_sizes is not specified, the list of sizes is generated
+    automatically following the pattern:
+
+        [1, 2, 4] + list(range(8, 256, 8)) + list(
+            range(256, max_cudagraph_capture_size + 1, 16))
+
+    If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
+    512) by default. This avoids OOM in tight memory scenarios with small
+    max_num_seqs, and prevents capture of many large graphs (>512) that would
+    greatly increase startup time with limited performance benefit.
+    """
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
     bs_to_padded_graph_size: list[int] = field(
         default_factory=list,
         init=False,
     )
     """optimization: Intuitively, bs_to_padded_graph_size should be dict[int, int].
-    since we know all keys are in a range [0, max_capture_size],
+    since we know all keys are in a range [0, max_cudagraph_capture_size],
     we can optimize it to list[int] for better lookup performance."""

     # keep track of enabled and disabled custom ops
@@ -672,25 +686,12 @@ class CompilationConfig:

         return VllmBackend(vllm_config)

-    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
-        """To complete the initialization of config,
-        we need to know the cudagraph sizes."""
-
-        if self.cudagraph_capture_sizes is None:
-            self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        else:
-            # de-duplicate the sizes provided by the config
-            dedup_sizes = list(set(self.cudagraph_capture_sizes))
-            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
-                logger.info(
-                    (
-                        "cudagraph sizes specified by model runner"
-                        " %s is overridden by config %s"
-                    ),
-                    cudagraph_capture_sizes,
-                    dedup_sizes,
-                )
-            self.cudagraph_capture_sizes = dedup_sizes
+    def post_init_cudagraph_sizes(self) -> None:
+        """To complete the initialization after cudagraph related
+        configs are set. This includes:
+        - initialize compile_sizes
+        - pre-compute the mapping bs_to_padded_graph_size
+        """

         computed_compile_sizes = []
         if self.compile_sizes is not None:
@@ -708,23 +709,24 @@ class CompilationConfig:
                 computed_compile_sizes.append(x)
         self.compile_sizes = computed_compile_sizes  # type: ignore

-        # sort to make sure cudagraph capture sizes are in descending order
-        self.cudagraph_capture_sizes.sort(reverse=True)
-        self.max_capture_size = (
-            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
-        )
+        # make sure the sizes are in ascending order
+        self.cudagraph_capture_sizes.sort()
+        if self.cudagraph_capture_sizes:
+            assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

         # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
         for end, start in zip(
-            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
         ):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
                 else:
                     self.bs_to_padded_graph_size[bs] = end
-        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size

     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 402c29eb641f..af47531501cf 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -71,14 +71,6 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""

-    cuda_graph_sizes: list[int] = field(default_factory=list)
-    """Cuda graph capture sizes
-    1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
-    2. if one value is provided, then the capture list would follow the
-    pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
-    3. more than one value (e.g. 1 2 128) is provided, then the capture list
-    will follow the provided list."""
-
     enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""
@@ -235,13 +227,6 @@ class SchedulerConfig:
                 self.long_prefill_token_threshold,
             )

-        # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
-        # This avoids OOM in tight memory scenarios with small max_num_seqs,
-        # and prevents capture of many large graphs (>512) that would greatly
-        # increase startup time with limited performance benefit.
-        if not self.cuda_graph_sizes:
-            self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
-
         if self.async_scheduling:
             self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index fa7310f13b03..472d6ed2c1df 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -197,10 +197,10 @@ class VllmConfig:
         return hash_str

     def pad_for_cudagraph(self, batch_size: int) -> int:
-        # if batch_size > self.compilation_config.max_capture_size,
+        # if batch_size > self.compilation_config.max_cudagraph_capture_size,
         # it should raise an IndexError.
         # the caller should make sure the batch_size is within the range,
-        # i.e., batch_size <= self.compilation_config.max_capture_size
+        # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
         return self.compilation_config.bs_to_padded_graph_size[batch_size]

     @staticmethod
@@ -396,6 +396,9 @@ class VllmConfig:
         if self.model_config is not None and self.model_config.enforce_eager:
             logger.info("Cudagraph is disabled under eager mode")
             self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+            # override related settings when enforce eager
+            self.compilation_config.max_cudagraph_capture_size = 0
+            self.compilation_config.cudagraph_capture_sizes = []
         elif envs.VLLM_USE_V1:
             self.compilation_config.cudagraph_num_of_warmups = 1

@@ -654,11 +657,13 @@ class VllmConfig:
        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
-        # 1, 2, 4, then multiples of 8 up to max_graph_size
-        cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
+        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
+        # up to max_graph_size
+        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
+            range(256, max_graph_size + 1, 16))
        ```

         In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
-        will be the final sizes to capture cudagraph (in descending order).
+        will be the final sizes to capture cudagraph (in ascending order).

         These sizes are used to capture and reuse CUDA graphs for
         performance-critical paths (e.g., decoding). Capturing enables
@@ -685,35 +690,111 @@ class VllmConfig:
         not be used.
         """

-        # calculate the default `batch_size_capture_list`
-        batch_size_capture_list = []
-        if self.model_config is not None and not self.model_config.enforce_eager:
-            cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
-            if len(cuda_graph_sizes) == 1:
-                max_graph_size = cuda_graph_sizes[0]
-                assert max_graph_size >= 1, (
-                    "Maximum cudagraph size should be greater than or equal to 1."
+        if (
+            self.model_config is not None
+            and not self.model_config.enforce_eager
+            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+        ):
+            # determine the initial max_cudagraph_capture_size
+            max_cudagraph_capture_size = (
+                self.compilation_config.max_cudagraph_capture_size
+            )
+            if max_cudagraph_capture_size is None:
+                max_cudagraph_capture_size = min(
+                    self.scheduler_config.max_num_seqs * 2, 512
                 )
-                batch_size_capture_list = [
-                    i for i in [1, 2, 4] if i <= max_graph_size
-                ] + list(range(8, max_graph_size + 1, 8))
-            elif len(cuda_graph_sizes) > 1:
-                batch_size_capture_list = sorted(cuda_graph_sizes)
+            max_num_tokens = self.scheduler_config.max_num_batched_tokens
+            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
+
+            assert max_cudagraph_capture_size >= 1, (
+                "Maximum cudagraph size should be greater than or equal to 1 "
+                "when using cuda graph."
+            )
+
+            # determine the cudagraph_capture_sizes
+            if self.compilation_config.cudagraph_capture_sizes is not None:
+                assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
+                    "cudagraph_capture_sizes should contain at least one element "
+                    "when using cuda graph."
+                )
+                # de-duplicate the sizes provided by the config
+                dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
+                cudagraph_capture_sizes = dedup_sizes
+                # sort to make sure the sizes are in ascending order
+                cudagraph_capture_sizes.sort()
             else:
-                raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")
+                cudagraph_capture_sizes = [
+                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
+                ]
+                if max_cudagraph_capture_size >= 8:
+                    # Step size 8 for small batch sizes, up to 256 (not included)
+                    cudagraph_capture_sizes += list(
+                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
+                    )
+                if max_cudagraph_capture_size >= 256:
+                    # Step size 16 for larger batch sizes
+                    cudagraph_capture_sizes += list(
+                        range(256, max_cudagraph_capture_size + 1, 16)
+                    )
+
             if (
                 self.parallel_config.tensor_parallel_size > 1
                 and self.compilation_config.pass_config.enable_sequence_parallelism
             ):
-                batch_size_capture_list = self.update_sizes_for_sequence_parallelism(
-                    batch_size_capture_list
+                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
+                    cudagraph_capture_sizes
                 )
-            max_num_tokens = self.scheduler_config.max_num_batched_tokens
-            batch_size_capture_list = [
-                size for size in batch_size_capture_list if size <= max_num_tokens
-            ]
-            self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list)
+            # user-specified compilation_config.max_cudagraph_capture_size gets
+            # truncated to valid_max_size when the two are inconsistent.
+            valid_max_size = (
+                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
+            )
+            if (
+                self.compilation_config.max_cudagraph_capture_size is not None
+                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
+            ):
+                # raise error only when both two flags are user-specified
+                # and they are inconsistent with each other
+                if self.compilation_config.cudagraph_capture_sizes is not None:
+                    raise ValueError(
+                        "customized max_cudagraph_capture_size"
+                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
+                        "should be consistent with the max value of "
+                        f"cudagraph_capture_sizes(={valid_max_size})"
+                    )
+
+                logger.warning(
+                    "Truncating max_cudagraph_capture_size to %d",
+                    valid_max_size,
+                )
+            # always set the final max_cudagraph_capture_size
+            self.compilation_config.max_cudagraph_capture_size = valid_max_size
+
+            if self.compilation_config.cudagraph_capture_sizes is not None and len(
+                cudagraph_capture_sizes
+            ) < len(self.compilation_config.cudagraph_capture_sizes):
+                # If users have specified capture sizes, we only need to
+                # compare the lens before and after modification since the modified
+                # list is only the subset of the original list.
+                logger.warning(
+                    (
+                        "cudagraph_capture_sizes specified in compilation_config"
+                        " %s is overridden by config %s"
+                    ),
+                    self.compilation_config.cudagraph_capture_sizes,
+                    cudagraph_capture_sizes,
+                )
+            # always write back the final sizes
+            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
+
+        else:
+            # no cudagraph in use
+            self.compilation_config.max_cudagraph_capture_size = 0
+            self.compilation_config.cudagraph_capture_sizes = []
+
+        # complete the remaining process.
+        self.compilation_config.post_init_cudagraph_sizes()

     def recalculate_max_model_len(self, max_model_len: int):
         # Can only be called in try_verify_and_update_config
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f6357c5e1ea7..c0ea84b6e4e8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -364,7 +364,13 @@ class EngineArgs:
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
     seed: int | None = ModelConfig.seed
     max_model_len: int | None = ModelConfig.max_model_len
-    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
+    cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
+    cudagraph_capture_sizes: list[int] | None = (
+        CompilationConfig.cudagraph_capture_sizes
+    )
+    max_cudagraph_capture_size: int | None = get_field(
+        CompilationConfig, "max_cudagraph_capture_size"
+    )
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
@@ -1007,9 +1013,6 @@ class EngineArgs:
             "--max-long-partial-prefills",
             **scheduler_kwargs["max_long_partial_prefills"],
         )
-        scheduler_group.add_argument(
-            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
-        )
         scheduler_group.add_argument(
             "--long-prefill-token-threshold",
             **scheduler_kwargs["long_prefill_token_threshold"],
         )
@@ -1039,6 +1042,29 @@ class EngineArgs:
             "--async-scheduling", **scheduler_kwargs["async_scheduling"]
         )

+        # Compilation arguments
+        compilation_kwargs = get_kwargs(CompilationConfig)
+        compilation_group = parser.add_argument_group(
+            title="CompilationConfig",
+            description=CompilationConfig.__doc__,
+        )
+        compilation_group.add_argument(
+            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
+        )
+        compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
+            "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
+            " whichever is soonest. Please use --cudagraph-capture-sizes instead."
+        )
+        compilation_group.add_argument(
+            "--cuda-graph-sizes",
+            **compilation_kwargs["cudagraph_capture_sizes"],
+            deprecated=True,
+        )
+        compilation_group.add_argument(
+            "--max-cudagraph-capture-size",
+            **compilation_kwargs["max_cudagraph_capture_size"],
+        )
+
         # vLLM arguments
         vllm_kwargs = get_kwargs(VllmConfig)
         vllm_group = parser.add_argument_group(
@@ -1548,7 +1574,6 @@ class EngineArgs:
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
-            cuda_graph_sizes=self.cuda_graph_sizes,
             num_lookahead_slots=num_lookahead_slots,
             enable_chunked_prefill=self.enable_chunked_prefill,
             disable_chunked_mm_input=self.disable_chunked_mm_input,
@@ -1616,6 +1641,38 @@ class EngineArgs:
             collect_detailed_traces=self.collect_detailed_traces,
         )

+        # Compilation config overrides
+        if self.cuda_graph_sizes is not None:
+            logger.warning(
+                "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
+                "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
+                "instead."
+            )
+            if self.compilation_config.cudagraph_capture_sizes is not None:
+                raise ValueError(
+                    "cuda_graph_sizes and compilation_config."
+                    "cudagraph_capture_sizes are mutually exclusive"
+                )
+            self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
+        if self.cudagraph_capture_sizes is not None:
+            if self.compilation_config.cudagraph_capture_sizes is not None:
+                raise ValueError(
+                    "cudagraph_capture_sizes and compilation_config."
+                    "cudagraph_capture_sizes are mutually exclusive"
+                )
+            self.compilation_config.cudagraph_capture_sizes = (
+                self.cudagraph_capture_sizes
+            )
+        if self.max_cudagraph_capture_size is not None:
+            if self.compilation_config.max_cudagraph_capture_size is not None:
+                raise ValueError(
+                    "max_cudagraph_capture_size and compilation_config."
+ "max_cudagraph_capture_size are mutually exclusive" + ) + self.compilation_config.max_cudagraph_capture_size = ( + self.max_cudagraph_capture_size + ) + config = VllmConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 73f911514968..96297c0c4d72 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -185,7 +185,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.moe = moe self.mxfp4_backend = get_mxfp4_backend() self.max_capture_size = ( - get_current_vllm_config().compilation_config.max_capture_size + get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) assert self.mxfp4_backend != Mxfp4Backend.NONE, ( diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index f1ec33ff3de9..d4367be1c785 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -259,21 +259,19 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): # Increase the max capture size from 512 to 992 for performance. # NOTE(woosuk): This will increase the number of CUDA graphs # from 67 to 81. - scheduler_config = vllm_config.scheduler_config - if len(scheduler_config.cuda_graph_sizes) == 1: - max_capture_size = scheduler_config.cuda_graph_sizes[0] + compilation_config = vllm_config.compilation_config + # Only override when the user has not set either of + # cudagraph_capture_sizes or max_cudagraph_capture_size. + if ( + compilation_config.cudagraph_capture_sizes is None + and compilation_config.max_cudagraph_capture_size is None + ): # FIXME(woosuk): When using full cuda graph with FA3, the max # supported size is 992. 
-            if max_capture_size < 992:
-                cuda_graph_sizes = [1, 2, 4]
-                # Step size 8 for small batch sizes
-                cuda_graph_sizes += [i for i in range(8, 256, 8)]
-                # Step size 16 for larger batch sizes
-                cuda_graph_sizes += [i for i in range(256, 993, 16)]
-                scheduler_config.cuda_graph_sizes = cuda_graph_sizes
-                logger.info(
-                    "Overriding max cuda graph capture size to %d for performance.", 992
-                )
+            compilation_config.max_cudagraph_capture_size = 992
+            logger.info(
+                "Overriding max cuda graph capture size to %d for performance.", 992
+            )


 class MambaModelConfig(VerifyAndUpdateConfig):
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 8affde914782..720fbd2c15c5 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -236,7 +236,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
-        self.max_cudagraph_size = self.compilation_config.max_capture_size
+        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

         if self.use_full_cuda_graph and self.aot_schedule:
             if self.max_cudagraph_size > 992:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 78fd7b3bcf73..029293d2f6dd 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -324,7 +324,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         ] = {}
         self._decode_cudagraph_max_bs = min(
             (1 + num_spec_tokens) * max_num_reqs,
-            self.compilation_config.max_capture_size,
+            self.compilation_config.max_cudagraph_capture_size,
         )

         self.num_qo_heads = self.model_config.get_num_attention_heads(
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index acfefde129f6..2ca19646911e 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -87,7 +87,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         )
         self.decode_cudagraph_max_bs = min(
             self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
-            self.compilation_config.max_capture_size,
+            self.compilation_config.max_cudagraph_capture_size,
         )

         self.spec_state_indices_tensor = torch.empty(
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 5aafb9813df0..52f26a9e61ca 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -36,7 +36,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         self.compilation_config = vllm_config.compilation_config
         self.decode_cudagraph_max_bs = min(
             self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size,
+            self.compilation_config.max_cudagraph_capture_size,
         )
         self.state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 18e5908d6ef1..a6aac701b784 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -89,7 +89,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
-        self.max_cudagraph_size = self.compilation_config.max_capture_size
+        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

         if self.use_full_cuda_graph and self.fa_aot_schedule:
             if self.max_cudagraph_size > 992:
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 635fcb2f8fb1..35c2e73e8ee2 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -104,7 +104,7 @@ class EagleProposer:
         )

         self.cudagraph_batch_sizes = (
-            list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
+            sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes)
             if self.use_cuda_graph
             else []
         )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index de63475cbc09..a08d2262f0f3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -379,16 +379,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.async_output_copy_stream = torch.cuda.Stream()
         self.prepare_inputs_event = torch.cuda.Event()

-        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
-        # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
-        # The batch sizes in the config are in descending order.
         if (
             self.compilation_config.cudagraph_capture_sizes
             and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
         ):
-            self.cudagraph_batch_sizes = list(
-                reversed(self.compilation_config.cudagraph_capture_sizes)
+            self.cudagraph_batch_sizes = sorted(
+                self.compilation_config.cudagraph_capture_sizes
             )

         # Cache the device properties.
@@ -3791,7 +3788,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
-
+        # make sure we capture the largest batch size first
         compilation_cases = list(
             product(reversed(self.cudagraph_batch_sizes), lora_cases)
         )
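To make the size-selection behavior above easier to follow, here is a small standalone sketch of the bucket pattern documented in `max_cudagraph_capture_size` and of the `bs_to_padded_graph_size` construction in `post_init_cudagraph_sizes`. It is an illustrative re-implementation, not the vLLM code itself; the helper names are invented for the example.

```python
# Illustration only: mirrors the capture-size pattern and padding-table logic
# described in the diff. Helper names are made up for this sketch.


def default_capture_sizes(max_capture_size: int) -> list[int]:
    """[1, 2, 4], then multiples of 8 up to 256 (exclusive), then multiples of 16."""
    sizes = [i for i in [1, 2, 4] if i <= max_capture_size]
    if max_capture_size >= 8:
        sizes += list(range(8, min(max_capture_size + 1, 256), 8))
    if max_capture_size >= 256:
        sizes += list(range(256, max_capture_size + 1, 16))
    return sizes


def build_padding_table(capture_sizes: list[int]) -> list[int]:
    """For each batch size up to the largest capture size, the graph size it pads to."""
    capture_sizes = sorted(capture_sizes)
    max_size = capture_sizes[-1] if capture_sizes else 0
    table = [0] * (max_size + 1)
    for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
        for bs in range(start, end):
            table[bs] = start if bs == start else end
    return table


if __name__ == "__main__":
    max_num_seqs, max_num_batched_tokens = 128, 2048
    # default cap: min(max_num_seqs * 2, 512), further clamped by the token budget
    max_size = min(min(max_num_seqs * 2, 512), max_num_batched_tokens)
    sizes = default_capture_sizes(max_size)  # [1, 2, 4, 8, 16, ..., 248, 256]
    table = build_padding_table(sizes)
    assert table[3] == 4 and table[250] == 256  # batch sizes pad up to the next bucket
```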
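And a minimal usage sketch of the new configuration surface. The consistency and override rules come from the `VllmConfig` and `EngineArgs` changes in this diff; a working install with this change is assumed, and "facebook/opt-125m" is only a placeholder model name.

```python
# Usage sketch (assumptions: vLLM with this change installed; placeholder model).
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import EngineArgs

# Option 1: give the capture list explicitly; max_cudagraph_capture_size is
# then inferred as the largest value in the list.
compilation_config = CompilationConfig(cudagraph_capture_sizes=[1, 2, 4, 8, 16])

# Option 2: give only the cap; the list is generated from the documented
# pattern [1, 2, 4] + range(8, 256, 8) + range(256, max + 1, 16).
# compilation_config = CompilationConfig(max_cudagraph_capture_size=992)

# Giving both is allowed only when they are consistent; otherwise VllmConfig
# raises while the engine config is being built.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
print(vllm_config.compilation_config.max_cudagraph_capture_size)  # -> 16

# CLI equivalents added in this diff: --cudagraph-capture-sizes and
# --max-cudagraph-capture-size; --cuda-graph-sizes is kept as a deprecated alias.
```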