Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-27 13:28:42 +08:00)
[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
commit 0f9e7354f5 (parent ba7ba35cda)
@@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(

        self.aot_schedule = (get_flash_attn_version() == 3)
        self.use_full_cuda_graph = compilation_config.full_cuda_graph
        if self.use_full_cuda_graph and not self.aot_schedule:
            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
                             "which requires FlashAttention 3.")
        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        if self.use_full_cuda_graph:
            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
            # yet. This is because the scheduler and kernel need to always use
            # the same num_splits (which acts as an upper bound with the
            # dynamic split scheduler) which is currently heuristically decided
            # by the kernel launching code.
            self.aot_schedule = False

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
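The NOTE(lucas) comment above is the crux of the fix: with the dynamic split scheduler, AOT-precomputed scheduler metadata bakes in one num_splits, while the kernel launch heuristically picks its own, and under full CUDA graph capture the two can disagree. A toy sketch of that hazard follows (plain PyTorch, not FA3 or vLLM code; plan_splits, run_kernel, and all sizes are illustrative assumptions):

import torch

def plan_splits(seq_len: int, planned_splits: int) -> torch.Tensor:
    # "AOT scheduler": allocate one partial-output slot per planned split.
    return torch.zeros(planned_splits, seq_len)

def run_kernel(partial_out: torch.Tensor, kernel_splits: int) -> torch.Tensor:
    # "Kernel": writes one row per split it decides to use at launch time.
    for s in range(kernel_splits):
        # If kernel_splits > planned_splits this write lands past the buffer;
        # a real GPU kernel hits an illegal memory access rather than a
        # Python IndexError.
        partial_out[s] += 1.0
    return partial_out.sum(dim=0)

plan = plan_splits(seq_len=16, planned_splits=2)
run_kernel(plan, kernel_splits=2)    # planner and kernel agree: fine
# run_kernel(plan, kernel_splits=4)  # disagreement -> out-of-bounds write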
@@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(

                max_seq_len=max_seq_len,
                causal=True)

        if self.use_full_cuda_graph:
            assert scheduler_metadata is not None
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n].copy_(scheduler_metadata,
                                              non_blocking=True)
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
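The NOTE(woosuk) block in this hunk shows the persistent-buffer pattern the previous code path relied on: copy this step's scheduler metadata into a fixed, graph-captured tensor and zero the tail so a replay never consumes stale entries from an earlier, longer step. A minimal standalone sketch of that pattern (plain PyTorch, not vLLM code; MAX_ITEMS, refill, and the shapes are assumptions for illustration):

import torch

MAX_ITEMS = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed-size buffer whose address stays stable across CUDA graph replays.
persistent = torch.zeros(MAX_ITEMS, dtype=torch.int32, device=device)

def refill(step_metadata: torch.Tensor) -> torch.Tensor:
    # Copy this step's metadata into the front of the persistent buffer.
    n = step_metadata.shape[0]
    persistent[:n].copy_(step_metadata, non_blocking=True)
    # Zero the tail: without this, a consumer that scans the whole captured
    # buffer can pick up entries left over from an earlier, longer step.
    persistent[n:] = 0
    return persistent[:n]

refill(torch.arange(6, dtype=torch.int32, device=device))   # long step
refill(torch.arange(3, dtype=torch.int32, device=device))   # shorter step
print(persistent.tolist())   # indices 3..7 are zero, not stale leftovers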