diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index be26e0060db5e..46da28540d110 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -75,7 +75,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             1, # MQA for the decode path
         )
 
-        if self.runner.full_cuda_graph:
+        n = num_splits.size(0)
+        if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
             # First time around (CUDAGraph capture), allocate the static buffer
             if self.cg_buf_tile_scheduler_metadata is None:
                 self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata