diff --git a/vllm/config.py b/vllm/config.py
index 864903ddc4468..09ed68bb64c8d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1865,6 +1865,13 @@ class SchedulerConfig:
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
+    cuda_graph_sizes: list[int] = field(default_factory=lambda: [512])
+    """CUDA graph capture sizes; the default is 512.
+    1. If a single value is provided, the capture list follows the pattern
+    [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)].
+    2. If more than one value is provided (e.g. 1 2 128),
+    the capture list follows the provided list."""
+
     max_num_seqs: int = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
 
@@ -4235,13 +4242,20 @@ class VllmConfig:
         batch_size_capture_list = []
         if self.model_config is not None and \
             not self.model_config.enforce_eager:
-            batch_size_capture_list = [1, 2, 4
-                                       ] + [i for i in range(8, 513, 8)]
+            cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
+            if len(cuda_graph_sizes) == 1:
+                batch_size_capture_list = [1, 2, 4] + [
+                    i for i in range(8, cuda_graph_sizes[0] + 1, 8)
+                ]
+            elif len(cuda_graph_sizes) > 1:
+                batch_size_capture_list = sorted(cuda_graph_sizes)
+            else:
+                raise TypeError(
+                    f"Invalid value for {cuda_graph_sizes=}.")
             if self.parallel_config.tensor_parallel_size > 1 and \
                 self.compilation_config.pass_config.enable_sequence_parallelism:
                 batch_size_capture_list = \
                     self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
-
             max_num_tokens = self.scheduler_config.max_num_batched_tokens
             batch_size_capture_list = [
                 size for size in batch_size_capture_list
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ed32be7cba593..3cafcb7c31f21 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -231,6 +231,8 @@ class EngineArgs:
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
     seed: Optional[int] = ModelConfig.seed
     max_model_len: Optional[int] = ModelConfig.max_model_len
+    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
+                                            "cuda_graph_sizes")
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
@@ -711,6 +713,8 @@ class EngineArgs:
         scheduler_group.add_argument(
             "--max-long-partial-prefills",
             **scheduler_kwargs["max_long_partial_prefills"])
+        scheduler_group.add_argument('--cuda-graph-sizes',
+                                     **scheduler_kwargs["cuda_graph_sizes"])
         scheduler_group.add_argument(
             "--long-prefill-token-threshold",
             **scheduler_kwargs["long_prefill_token_threshold"])
@@ -1042,6 +1046,7 @@ class EngineArgs:
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
+            cuda_graph_sizes=self.cuda_graph_sizes,
             num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
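
For reference, a minimal standalone sketch of the capture-list selection introduced in `VllmConfig` above; the `build_capture_list` helper name is only for illustration, since the patch computes the list inline.

    # Sketch of the added logic: a single value expands into the default-style
    # pattern, while several values are captured exactly as given (sorted).
    def build_capture_list(cuda_graph_sizes: list[int]) -> list[int]:
        if len(cuda_graph_sizes) == 1:
            # Single value: [1, 2, 4] plus every multiple of 8 up to the value.
            return [1, 2, 4] + [i for i in range(8, cuda_graph_sizes[0] + 1, 8)]
        elif len(cuda_graph_sizes) > 1:
            # Multiple values: capture exactly the provided batch sizes.
            return sorted(cuda_graph_sizes)
        else:
            raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")

    print(build_capture_list([16]))         # -> [1, 2, 4, 8, 16]
    print(build_capture_list([1, 2, 128]))  # -> [1, 2, 128]

Assuming the new `--cuda-graph-sizes` flag accepts space-separated integers, as the docstring example `1 2 128` suggests, a single value shrinks or grows the default capture range (512), while several values capture exactly those batch sizes.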