diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 31a706475243c..763bd61834625 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -27,9 +27,11 @@ class CudaGraphManager:
         device: torch.device,
     ):
         self.vllm_config = vllm_config
+        self.scheduler_config = vllm_config.scheduler_config
         self.device = device

         self.max_model_len = vllm_config.model_config.max_model_len
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
@@ -39,9 +41,11 @@ class CudaGraphManager:
         else:
             self.cudagraph_mode = self.compilation_config.cudagraph_mode
         if self.compilation_config.cudagraph_capture_sizes is not None:
-            self.cudagraph_sizes = sorted(
-                self.compilation_config.cudagraph_capture_sizes
-            )
+            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+            # Limit the cudagraph sizes to the max decode batch size.
+            self.cudagraph_sizes = [
+                x for x in cudagraph_sizes if x <= self.max_num_reqs
+            ]
         else:
             self.cudagraph_sizes = []
         self.padded_sizes = self._init_padded_sizes()
@@ -54,9 +58,10 @@ class CudaGraphManager:
         if not self.cudagraph_mode.has_full_cudagraphs():
             # Full cuda graphs are not used.
             return {}
+        if not self.cudagraph_sizes:
+            return {}

         padded_sizes: dict[int, int] = {}
-        assert len(self.cudagraph_sizes) > 0
         for i in range(1, self.cudagraph_sizes[-1] + 1):
             for x in self.cudagraph_sizes:
                 if i <= x:
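
To make the effect of this change concrete, here is a minimal, self-contained sketch of the filtering and padding logic, assuming the behavior shown in the hunks above. The helper names `filter_capture_sizes` and `build_padded_sizes` are illustrative stand-ins, not the actual `CudaGraphManager` methods.

```python
def filter_capture_sizes(capture_sizes: list[int], max_num_reqs: int) -> list[int]:
    # Sort, then drop any capture size larger than the max decode batch size
    # (scheduler_config.max_num_seqs), mirroring the new list comprehension
    # in __init__.
    return [x for x in sorted(capture_sizes) if x <= max_num_reqs]


def build_padded_sizes(cudagraph_sizes: list[int]) -> dict[int, int]:
    # Mirrors the loop shown in the last hunk: map each batch size up to the
    # largest capture size onto the smallest capture size that can hold it.
    if not cudagraph_sizes:
        # The new early return: if filtering removed every capture size,
        # return an empty table instead of tripping the old assert.
        return {}
    padded_sizes: dict[int, int] = {}
    for i in range(1, cudagraph_sizes[-1] + 1):
        for x in cudagraph_sizes:
            if i <= x:
                padded_sizes[i] = x
                break
    return padded_sizes


# With capture sizes [1, 2, 4, 8] and max_num_seqs = 4, size 8 is dropped and
# batch sizes 1..4 pad up to the nearest remaining capture size.
sizes = filter_capture_sizes([8, 4, 2, 1], max_num_reqs=4)
assert sizes == [1, 2, 4]
assert build_padded_sizes(sizes) == {1: 1, 2: 2, 3: 4, 4: 4}
```

Under this reading, capture sizes above `max_num_seqs` can never be hit at runtime, so skipping them avoids capturing graphs that would go unused, and the early return keeps the previously asserted non-empty-list invariant from failing when every size is filtered out.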