[Bugfix] Fix cuda graph sizes when running with speculative decoding (#30330)
Signed-off-by: Patryk Saffer <patryk.saffer99@gmail.com>
Signed-off-by: PatrykSaffer <patryk.saffer@mistral.ai>
Co-authored-by: Patryk Saffer <patryk.saffer99@gmail.com>
This commit is contained in:
parent 03b5f940fd
commit 4c2e10ea19
@@ -1047,8 +1047,14 @@ class VllmConfig:
             self.compilation_config.max_cudagraph_capture_size
         )
         if max_cudagraph_capture_size is None:
+            decode_query_len = 1
+            if (
+                self.speculative_config
+                and self.speculative_config.num_speculative_tokens
+            ):
+                decode_query_len += self.speculative_config.num_speculative_tokens
             max_cudagraph_capture_size = min(
-                self.scheduler_config.max_num_seqs * 2, 512
+                self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
             )
             max_num_tokens = self.scheduler_config.max_num_batched_tokens
             max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
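For intuition, here is a minimal sketch of the sizing arithmetic after this change. The helper name capture_size and the concrete numbers (max_num_seqs=128, num_speculative_tokens=4, max_num_batched_tokens=8192) are illustrative assumptions, not values from the commit: with speculative decoding, each decode step carries 1 + num_speculative_tokens query tokens per sequence, so the CUDA graph capture size must scale by that factor.

# Minimal sketch of the capture-size computation after this commit.
# Names and values below are illustrative, not taken from vLLM itself.
def capture_size(
    max_num_seqs: int,
    num_speculative_tokens: int,
    max_num_batched_tokens: int,
) -> int:
    # Each decode step covers the verified token plus any draft tokens.
    decode_query_len = 1 + num_speculative_tokens
    size = min(max_num_seqs * decode_query_len * 2, 512)
    # Never capture graphs larger than the scheduler can batch.
    return min(max_num_batched_tokens, size)

# Without speculative decoding: min(128 * 1 * 2, 512) = 256
print(capture_size(128, 0, 8192))  # 256
# With 4 speculative tokens: min(128 * 5 * 2, 512) = 512
print(capture_size(128, 4, 8192))  # 512

Before the fix, the capture size ignored decode_query_len, so with speculative decoding enabled the captured graphs could be too small to cover a full decode batch.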