mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 04:19:38 +08:00
[Model Runner V2] Limit cudagraph size to max decode batch size (#29221)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
1489902b53
commit
e9056056fb
@ -27,9 +27,11 @@ class CudaGraphManager:
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
@ -39,9 +41,11 @@ class CudaGraphManager:
|
||||
else:
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
if self.compilation_config.cudagraph_capture_sizes is not None:
|
||||
self.cudagraph_sizes = sorted(
|
||||
self.compilation_config.cudagraph_capture_sizes
|
||||
)
|
||||
cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
|
||||
# Limit the cudagraph sizes to the max decode batch size.
|
||||
self.cudagraph_sizes = [
|
||||
x for x in cudagraph_sizes if x <= self.max_num_reqs
|
||||
]
|
||||
else:
|
||||
self.cudagraph_sizes = []
|
||||
self.padded_sizes = self._init_padded_sizes()
|
||||
@ -54,9 +58,10 @@ class CudaGraphManager:
|
||||
if not self.cudagraph_mode.has_full_cudagraphs():
|
||||
# Full cuda graphs are not used.
|
||||
return {}
|
||||
if not self.cudagraph_sizes:
|
||||
return {}
|
||||
|
||||
padded_sizes: dict[int, int] = {}
|
||||
assert len(self.cudagraph_sizes) > 0
|
||||
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
||||
for x in self.cudagraph_sizes:
|
||||
if i <= x:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user