From 7e2ff2620e1d47ee281800e01ae06e1de9d53011 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 3 Jul 2025 13:45:07 +0000 Subject: [PATCH] cleanup flashmla.py Signed-off-by: Sage Moore --- vllm/v1/attention/backends/mla/flashmla.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index ed0a35f1b460f..f67278b6693db 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -78,11 +78,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): if self.runner.full_cuda_graph: n = num_splits.size(0) # First time around (CUDAGraph capture), allocate the static buffer - if self.cg_buf_num_splits is None: - self.cg_buf_num_splits = num_splits + if self.cg_buf_tile_scheduler_metadata is None: self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata - elif n <= self.cg_buf_num_splits.size(0): - assert self.cg_buf_tile_scheduler_metadata is not None + self.cg_buf_num_splits = num_splits + else: + assert self.cg_buf_num_splits is not None # Metadata per-SM, fixed size (#SMs, TileMetadataSize) assert (self.cg_buf_tile_scheduler_metadata.size() == @@ -93,7 +93,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): # Num splits is per-batch, varying size (batch_size,) n = num_splits.size(0) - # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}") # make sure static buffer is large enough assert n <= self.cg_buf_num_splits.size(0) num_splits_view = self.cg_buf_num_splits[:n]