diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 585d3997cc3a..107df502e08e 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -580,9 +580,12 @@ class VllmConfig:
                 not self.model_config.enforce_eager:
             cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
             if len(cuda_graph_sizes) == 1:
-                batch_size_capture_list = [1, 2, 4] + [
-                    i for i in range(8, cuda_graph_sizes[0] + 1, 8)
-                ]
+                max_graph_size = cuda_graph_sizes[0]
+                assert max_graph_size >= 1, "Maximum cudagraph size should be" \
+                    " greater than or equal to 1."
+                batch_size_capture_list = [
+                    i for i in [1, 2, 4] if i <= max_graph_size
+                ] + list(range(8, max_graph_size + 1, 8))
             elif len(cuda_graph_sizes) > 1:
                 batch_size_capture_list = sorted(cuda_graph_sizes)
             else:
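
For context, a minimal standalone sketch (not part of the patch) of what the new branch computes; `compute_capture_list` is a hypothetical helper name used only for illustration. The behavioral change is visible when the configured maximum is smaller than 4: the old expression unconditionally prepended 1, 2, 4, so the capture list could contain sizes larger than the requested maximum, while the patched version filters them out.

```python
# Standalone sketch of the patched single-element cuda_graph_sizes branch.
# compute_capture_list is a hypothetical name, not a function in vLLM.

def compute_capture_list(max_graph_size: int) -> list[int]:
    """Mirror the patched logic for a single configured cudagraph size."""
    assert max_graph_size >= 1, \
        "Maximum cudagraph size should be greater than or equal to 1."
    # Keep the small warm-up sizes 1, 2, 4 only when they do not exceed the
    # configured maximum, then add multiples of 8 up to that maximum.
    return [i for i in [1, 2, 4] if i <= max_graph_size] + \
        list(range(8, max_graph_size + 1, 8))

print(compute_capture_list(2))   # [1, 2] -- the old code returned [1, 2, 4]
print(compute_capture_list(16))  # [1, 2, 4, 8, 16] -- same as the old code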