[Bugfix] Fix CUDA graph selection bug in FlashInfer at high concurrency (#26499)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett 2025-10-09 17:12:34 -04:00 committed by GitHub
parent c9d33c60dc
commit 6e783bc54b

@@ -296,6 +296,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        speculative_config = vllm_config.speculative_config
+        num_spec_tokens = (
+            speculative_config.num_speculative_tokens
+            if speculative_config is not None
+            else 0
+        )
         self.enable_cuda_graph = (
             self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
         )
@@ -306,7 +312,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             int, BatchDecodeWithPagedKVCacheWrapper
         ] = {}
         self._decode_cudagraph_max_bs = min(
-            max_num_reqs, self.compilation_config.max_capture_size
+            (1 + num_spec_tokens) * max_num_reqs,
+            self.compilation_config.max_capture_size,
         )
         self.num_qo_heads = self.model_config.get_num_attention_heads(
@@ -679,7 +686,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         use_cudagraph = (
             self.enable_cuda_graph
             and pure_decode
-            and num_decodes <= self._decode_cudagraph_max_bs
+            and num_decode_tokens <= self._decode_cudagraph_max_bs
         )
         if use_cudagraph:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
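
For context, a minimal standalone sketch of the sizing logic this change adjusts. The variable names mirror the diff; the free functions and their signatures are simplified assumptions for illustration, not vLLM's actual API.

def decode_cudagraph_max_bs(max_num_reqs: int, num_spec_tokens: int,
                            max_capture_size: int) -> int:
    # With speculative decoding, each decode request contributes
    # 1 + num_spec_tokens query tokens, so the largest CUDA graph batch
    # is sized in tokens rather than in requests (as in the new code above).
    return min((1 + num_spec_tokens) * max_num_reqs, max_capture_size)

def should_use_cudagraph(enable_cuda_graph: bool, pure_decode: bool,
                         num_decode_tokens: int, max_bs: int) -> bool:
    # The old check compared the decode request count (num_decodes) against
    # a limit capped at max_num_reqs; with speculative decoding the decode
    # token count can exceed that at high concurrency, selecting a CUDA graph
    # size that was never captured. Comparing num_decode_tokens against the
    # token-sized limit keeps the selection within captured sizes.
    return (enable_cuda_graph
            and pure_decode
            and num_decode_tokens <= max_bs)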