[BugFix] Potential Fix for FA3 full-cudagraph IMA (#25490)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-09-24 05:04:04 -04:00 committed by GitHub
parent 2e19a848d4
commit 2338daffd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -194,10 +194,9 @@ class FlashAttentionMetadataBuilder(
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
self.max_cudagraph_size = self.compilation_config.max_capture_size
if self.use_full_cuda_graph and self.aot_schedule:
self.max_cudagraph_size = self.compilation_config.max_capture_size
if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
@ -259,6 +258,15 @@ class FlashAttentionMetadataBuilder(
self.aot_schedule = False
aot_schedule = False
max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
if self.use_full_cuda_graph and \
num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
cache_dtype = self.cache_config.cache_dtype
@ -281,7 +289,7 @@ class FlashAttentionMetadataBuilder(
page_size=self.block_size,
causal=causal,
window_size=self.aot_sliding_window,
num_splits=self.max_num_splits,
num_splits=max_num_splits,
)
return None
@ -322,7 +330,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len=max_seq_len,
causal=causal)
# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
@ -333,13 +340,6 @@ class FlashAttentionMetadataBuilder(
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,