Cleanup flashmla.py

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-03 13:45:07 +00:00
parent 1d75a029a9
commit 7e2ff2620e

View File

@ -78,11 +78,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
if self.runner.full_cuda_graph:
n = num_splits.size(0)
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_num_splits is None:
self.cg_buf_num_splits = num_splits
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits.size(0):
assert self.cg_buf_tile_scheduler_metadata is not None
self.cg_buf_num_splits = num_splits
else:
assert self.cg_buf_num_splits is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata.size() ==
@ -93,7 +93,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]