[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-06-25 04:39:04 -04:00 committed by GitHub
parent ba7ba35cda
commit 0f9e7354f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph
if self.use_full_cuda_graph and not self.aot_schedule:
raise ValueError("Full CUDA graph mode requires AOT scheduling, "
"which requires FlashAttention 3.")
self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
dtype=torch.int32,
device=self.runner.device)
if self.use_full_cuda_graph:
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode
# yet. This is because the scheduler and kernel need to always use
# the same num_splits (which acts as an upper bound with the
# dynamic split scheduler) which is currently heuristically decided
# by the kernel launching code.
self.aot_schedule = False
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len=max_seq_len,
causal=True)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n].copy_(scheduler_metadata,
non_blocking=True)
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,