diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index ed0a35f1b460f..f67278b6693db 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -78,11 +78,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         if self.runner.full_cuda_graph:
             n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_num_splits is None:
-                self.cg_buf_num_splits = num_splits
+            if self.cg_buf_tile_scheduler_metadata is None:
                 self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-            elif n <= self.cg_buf_num_splits.size(0):
-                assert self.cg_buf_tile_scheduler_metadata is not None
+                self.cg_buf_num_splits = num_splits
+            else:
+                assert self.cg_buf_num_splits is not None
 
                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                 assert (self.cg_buf_tile_scheduler_metadata.size() ==
@@ -93,7 +93,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
 
                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
-                # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
                 assert n <= self.cg_buf_num_splits.size(0)
                 num_splits_view = self.cg_buf_num_splits[:n]
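
The rewrite keys both static buffers off a single capture-time check: on the first call (CUDAGraph capture) both buffers are adopted together, and on every later call their existence is asserted rather than silently skipped, as the old `elif n <= self.cg_buf_num_splits.size(0):` branch did whenever the batch outgrew the buffer. Below is a minimal sketch of the underlying static-buffer pattern; the names `StaticBufferHolder` and `build` are hypothetical and not part of vLLM.

```python
# Minimal sketch of the CUDA-graph static-buffer pattern the diff moves
# toward. StaticBufferHolder/build are illustrative names, not vLLM's API.
from typing import Optional

import torch


class StaticBufferHolder:
    def __init__(self) -> None:
        # Allocated lazily on the first build() call, i.e. at CUDAGraph capture.
        self.cg_buf: Optional[torch.Tensor] = None

    def build(self, values: torch.Tensor) -> torch.Tensor:
        if self.cg_buf is None:
            # Capture time: adopt the tensor itself as the static buffer, so
            # the captured graph records this exact device address.
            self.cg_buf = values
            return values
        # Replay time: never reallocate. The new batch may be smaller than
        # the buffer, so copy into a prefix view of the fixed-size buffer.
        n = values.size(0)
        assert n <= self.cg_buf.size(0)  # static buffer must be large enough
        view = self.cg_buf[:n]
        view.copy_(values, non_blocking=True)
        return view
```

Copying into a prefix view preserves the device address the captured graph recorded at capture time, which is the invariant full CUDA graph mode depends on.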