Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix CUDA graph selection bug in FlashInfer at high concurrency (#26499)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent c9d33c60dc
commit 6e783bc54b
@@ -296,6 +296,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        speculative_config = vllm_config.speculative_config
+        num_spec_tokens = (
+            speculative_config.num_speculative_tokens
+            if speculative_config is not None
+            else 0
+        )
         self.enable_cuda_graph = (
             self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
         )
@@ -306,7 +312,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             int, BatchDecodeWithPagedKVCacheWrapper
         ] = {}
         self._decode_cudagraph_max_bs = min(
-            max_num_reqs, self.compilation_config.max_capture_size
+            (1 + num_spec_tokens) * max_num_reqs,
+            self.compilation_config.max_capture_size,
         )
 
         self.num_qo_heads = self.model_config.get_num_attention_heads(
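Taken together, the first two hunks size the decode CUDA-graph batch-size cap in tokens rather than requests once speculative decoding is in play. The sketch below is a minimal illustration, not vLLM code; the helper name and the concrete numbers are assumptions chosen to show the arithmetic.

def decode_cudagraph_max_bs(
    max_num_reqs: int, num_spec_tokens: int, max_capture_size: int
) -> int:
    # Each decode request contributes one query token plus one token per
    # speculative draft token, so the per-step decode token budget is
    # (1 + num_spec_tokens) * max_num_reqs, clamped to the largest batch
    # size captured for CUDA graphs.
    return min((1 + num_spec_tokens) * max_num_reqs, max_capture_size)

# Illustrative numbers: 256 scheduled requests, 2 speculative tokens each,
# and a capture table whose largest size is 512.
# Old cap: min(256, 512) = 256.
# New cap: min((1 + 2) * 256, 512) = 512.
print(decode_cudagraph_max_bs(max_num_reqs=256, num_spec_tokens=2, max_capture_size=512))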
@@ -679,7 +686,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         use_cudagraph = (
             self.enable_cuda_graph
             and pure_decode
-            and num_decodes <= self._decode_cudagraph_max_bs
+            and num_decode_tokens <= self._decode_cudagraph_max_bs
         )
         if use_cudagraph:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
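The third hunk is the selection fix itself: eligibility for the captured decode graph is now judged by the number of decode tokens rather than the number of decode requests. Below is a minimal sketch of the corrected check, written under the assumption that each decode request expands into (1 + num_spec_tokens) query tokens; the function and argument names are hypothetical, not vLLM APIs.

def should_use_decode_cudagraph(
    enable_cuda_graph: bool,
    pure_decode: bool,
    num_decodes: int,
    num_spec_tokens: int,
    decode_cudagraph_max_bs: int,
) -> bool:
    # With speculative decoding each decode request carries
    # (1 + num_spec_tokens) query tokens, so the token count is what must
    # fit inside the largest captured graph size.
    num_decode_tokens = (1 + num_spec_tokens) * num_decodes
    return (
        enable_cuda_graph
        and pure_decode
        and num_decode_tokens <= decode_cudagraph_max_bs
    )

# Example: 300 pure-decode requests with 1 speculative token each produce
# 600 decode tokens. Comparing the request count (300) against a 512-entry
# cap would pass; comparing the token count (600) correctly fails, so the
# full-CUDA-graph decode path is skipped for this batch.
print(should_use_decode_cudagraph(True, True, 300, 1, 512))  # False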