[Bugfix] Fix CUDA graph selection bug in FlashInfer at high concurrency (#26499)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent c9d33c60dc
commit 6e783bc54b
@@ -296,6 +296,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        speculative_config = vllm_config.speculative_config
+        num_spec_tokens = (
+            speculative_config.num_speculative_tokens
+            if speculative_config is not None
+            else 0
+        )
         self.enable_cuda_graph = (
             self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
         )
@@ -306,7 +312,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 int, BatchDecodeWithPagedKVCacheWrapper
             ] = {}
             self._decode_cudagraph_max_bs = min(
-                max_num_reqs, self.compilation_config.max_capture_size
+                (1 + num_spec_tokens) * max_num_reqs,
+                self.compilation_config.max_capture_size,
             )

             self.num_qo_heads = self.model_config.get_num_attention_heads(
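The sizing change in the hunk above can be illustrated in isolation. The sketch below is a hypothetical standalone helper (the name decode_cudagraph_max_bs and the sample numbers are made up for illustration, not vLLM API); it only mirrors the min(...) expression introduced here: with speculative decoding, each decode request contributes 1 + num_spec_tokens query tokens, so the largest captured decode graph must be sized in tokens rather than requests.

```python
# Illustration only: size the decode CUDA graph in tokens, not requests,
# because each decode request can carry (1 + num_spec_tokens) query tokens.

def decode_cudagraph_max_bs(max_num_reqs: int,
                            num_spec_tokens: int,
                            max_capture_size: int) -> int:
    # Mirrors the updated min(...) in the diff above.
    return min((1 + num_spec_tokens) * max_num_reqs, max_capture_size)

# Hypothetical numbers: 256 max sequences, 2 speculative tokens, capture cap 512.
# The old sizing would have been min(256, 512) = 256, which is smaller than the
# 3 * 256 = 768 decode tokens a full speculative batch can produce.
print(decode_cudagraph_max_bs(256, 2, 512))  # -> 512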
@@ -679,7 +686,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         use_cudagraph = (
             self.enable_cuda_graph
             and pure_decode
-            and num_decodes <= self._decode_cudagraph_max_bs
+            and num_decode_tokens <= self._decode_cudagraph_max_bs
         )
         if use_cudagraph:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
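The selection change in the last hunk can be read the same way. This is a minimal sketch with made-up counts, using a hypothetical use_cudagraph helper that mirrors the boolean expression in the diff; it is not the builder's actual method.

```python
# Illustration only: the eligibility check must compare decode *tokens*, not
# decode *requests*, against the captured size; otherwise a speculative-decode
# batch at high concurrency can be routed to a CUDA graph that is too small.

def use_cudagraph(enable_cuda_graph: bool,
                  pure_decode: bool,
                  num_decode_tokens: int,
                  decode_cudagraph_max_bs: int) -> bool:
    return (enable_cuda_graph
            and pure_decode
            and num_decode_tokens <= decode_cudagraph_max_bs)

# Hypothetical batch: 200 decode requests with 2 speculative tokens each
# -> 600 decode tokens against a 512-token captured maximum.
# A request-count check (200 <= 512) would have selected the graph;
# the token-count check correctly falls back to the non-captured path.
print(use_cudagraph(True, True, 600, 512))  # -> False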