remove cudagraph logic from flashmla.py

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-03 13:41:49 +00:00
parent 17a7ceef27
commit 1d75a029a9
3 changed files with 14 additions and 19 deletions

View File

@@ -475,7 +475,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
ubatch_id: int = 0
) -> M:
num_reqs = req_slice.stop - req_slice.start
num_tokens = token_slice.stop - token_slice.start
@@ -587,7 +586,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
ubatch_id=ubatch_id
)
return self.metadata_cls(

View File

@@ -63,12 +63,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
self.cg_buf_tile_scheduler_metadata = [None, None]
self.cg_buf_num_splits = [None, None]
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor, ubatch_id = 0) -> FlashMLADecodeMetadata:
assert ubatch_id < 2
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
@@ -76,31 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
# logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
if self.runner.full_cuda_graph:
n = num_splits.size(0)
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_num_splits[ubatch_id] is None:
self.cg_buf_num_splits[ubatch_id] = num_splits
self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
if self.cg_buf_num_splits is None:
self.cg_buf_num_splits = num_splits
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits.size(0):
assert self.cg_buf_tile_scheduler_metadata is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
assert (self.cg_buf_tile_scheduler_metadata.size() ==
tile_scheduler_metadata.size())
self.cg_buf_tile_scheduler_metadata[ubatch_id].\
self.cg_buf_tile_scheduler_metadata.\
copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
self.cg_buf_num_splits[ubatch_id][n:].fill_(0) # fill the rest with 0s
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
num_splits = num_splits_view
return FlashMLADecodeMetadata(

View File

@@ -851,7 +851,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_query_len=max(tokens[req_slice]),
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
ubatch_id=ubid
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list