[BugFix] Fix illegal memory access (IMA) in Flash Attention MLA with full CUDA graph, following PR #25490 (#27128)

Signed-off-by: qqma <qqma@amazon.com>
Co-authored-by: qqma <qqma@amazon.com>
Daisy-Ma-coder 2025-10-22 12:36:39 -07:00 committed by GitHub
parent 8669c69afa
commit 5beacce2ea
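A reading of the diff below (hedged, not an authoritative root-cause writeup): before this fix, _schedule_decode always planned the FA3 kernel with num_splits=self.max_num_splits, while the value actually applied at attention time (the local max_num_splits) was computed only after scheduling and could differ, staying 0 when num_decode_tokens exceeded max_cudagraph_size. Scheduler metadata planned for one split count and a launch using another can address intermediate buffers that were never allocated, which matches the reported illegal memory access (IMA). A minimal sketch of the mismatch and the fix, with simplified, hypothetical signatures:

    # Sketch only; names mirror the diff, surrounding code is elided.
    def build_before(self, num_decode_tokens):
        # Pre-fix: metadata planned with self.max_num_splits unconditionally...
        meta = self._schedule_decode(num_splits=self.max_num_splits)
        max_num_splits = 0
        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
            max_num_splits = self.max_num_splits
        return meta, max_num_splits  # ...but the kernel may run with 0 splits

    def build_after(self, num_decode_tokens):
        # Post-fix: decide the split count first, pass the same value everywhere.
        max_num_splits = 0
        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
            max_num_splits = self.max_num_splits
        meta = self._schedule_decode(num_splits=max_num_splits)
        return meta, max_num_splits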

@@ -89,10 +89,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
         if self.use_full_cuda_graph and self.fa_aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
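The hunk above hoists self.max_cudagraph_size out of the use_full_cuda_graph and fa_aot_schedule guard. The build path below compares num_decode_tokens against self.max_cudagraph_size whenever use_full_cuda_graph is set, so leaving the assignment inside the narrower guard would leave the attribute undefined on some paths. A hedged, simplified sketch of that failure mode (names condensed, not the real class):

    class Builder:
        def __init__(self, use_full_cuda_graph, fa_aot_schedule, max_capture_size):
            self.use_full_cuda_graph = use_full_cuda_graph
            if use_full_cuda_graph and fa_aot_schedule:  # pre-fix placement
                self.max_cudagraph_size = max_capture_size

    b = Builder(use_full_cuda_graph=True, fa_aot_schedule=False, max_capture_size=512)
    if b.use_full_cuda_graph:
        b.max_cudagraph_size  # AttributeError: never assigned on this path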
@@ -114,7 +113,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
            self.max_num_splits = 1

     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
+        self,
+        num_reqs,
+        cu_query_lens,
+        max_query_len,
+        seqlens,
+        max_seq_len,
+        causal,
+        max_num_splits,
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -130,7 +136,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
-                num_splits=self.max_num_splits,
+                num_splits=max_num_splits,
             )
         return None
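These two hunks change _schedule_decode from reading the mutable attribute self.max_num_splits to taking max_num_splits as an explicit parameter, so each build decides the planned split count and the scheduler cannot drift from the caller. A hypothetical call site under the new signature (argument values are illustrative, tensor arguments elided):

    scheduler_metadata = builder._schedule_decode(
        num_reqs=4,
        cu_query_lens=cu_query_lens,
        max_query_len=1,
        seqlens=seqlens,
        max_seq_len=2048,
        causal=True,
        max_num_splits=0,  # 0 is assumed here to defer splitting to FA3's heuristic
    )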
@@ -148,6 +154,15 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()

+        # For Flash Attention MLA + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
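The NOTE's memory concern can be made concrete. Assuming fp32 accumulation and illustrative shapes (none of these numbers come from the source), the partial-output buffer of size [num_splits, num_heads, num_tokens, head_size] scales linearly with the split count:

    # Back-of-envelope sizing for the intermediate buffer named in the NOTE.
    num_splits, num_heads, num_tokens, head_size = 32, 16, 992, 576  # illustrative
    buf_bytes = num_splits * num_heads * num_tokens * head_size * 4  # fp32 elements
    print(f"{buf_bytes / 2**20:.0f} MiB")  # ~1116 MiB; with num_splits=1, ~35 MiB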
@@ -155,10 +170,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             seqlens=seq_lens_device,
             max_seq_len=max_seq_len,
             causal=True,
+            max_num_splits=max_num_splits,
         )

-        # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             # Ensure the persistent buffer is large enough
@@ -174,13 +188,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
                 self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]

-            if num_decode_tokens <= self.max_cudagraph_size:
-                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-                # usage, because the intermediate buffers of size [num_splits,
-                # num_heads, num_tokens, head_size] are allocated. Therefore,
-                # we only set num_splits when using cuda graphs.
-                max_num_splits = self.max_num_splits
-
         if vllm_is_batch_invariant():
             max_num_splits = 1
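The final lines keep the batch-invariance override: when vllm_is_batch_invariant() is set, max_num_splits is forced to 1 so the split-KV reduction order cannot vary with batch composition. Floating-point addition is not associative, so the number of partial sums can change bitwise results, as this self-contained snippet shows:

    # Why split count can change bitwise results (float non-associativity):
    x = [1e16, 1.0, -1e16, 1.0]
    one_split = ((x[0] + x[1]) + x[2]) + x[3]   # sequential reduction -> 1.0
    two_splits = (x[0] + x[1]) + (x[2] + x[3])  # two partial sums combined -> 0.0
    print(one_split, two_splits)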