[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (#20301)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-07-01 09:07:36 -07:00 committed by GitHub
parent 314af8617c
commit 8acb4badee
2 changed files with 54 additions and 7 deletions

cmake/external_projects/vllm_flash_attn.cmake

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
+        GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/v1/attention/backends/flash_attn.py

@@ -36,6 +36,9 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
+# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttentionBackend(AttentionBackend):
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
     # for local attention
     @dataclass
@@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
         if self.use_full_cuda_graph:
-            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
-            # yet. This is because the scheduler and kernel need to always use
-            # the same num_splits (which acts as an upper bound with the
-            # dynamic split scheduler) which is currently heuristically decided
-            # by the kernel launching code.
-            self.aot_schedule = False
+            if not self.aot_schedule:
+                raise ValueError(
+                    "AoT scheduling is required for full cuda graph.")
+            capture_sizes = compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                raise ValueError(
+                    "cudagraph_capture_sizes should not be None when "
+                    "full_cuda_graph is True.")
+            self.max_cudagraph_size = max(capture_sizes)
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                self.runner.max_num_reqs + 1,
+                dtype=torch.int32,
+                device=self.runner.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
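The new constructor logic above boils down to three preconditions plus a one-time allocation. A minimal standalone sketch of the same flow, where `capture_sizes`, `max_num_reqs`, and `device` are stand-ins for the real vLLM config objects (assumptions, not vLLM's API):

    import torch

    _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16  # from this diff

    def setup_full_cuda_graph(aot_schedule: bool, capture_sizes: list[int],
                              max_num_reqs: int, device: str = "cuda"):
        # Precondition 1: full cuda graphs need FA3's AoT scheduler.
        if not aot_schedule:
            raise ValueError("AoT scheduling is required for full cuda graph.")
        # Precondition 2: we must know the largest graph to be captured.
        if not capture_sizes:
            raise ValueError("cudagraph_capture_sizes should not be None "
                             "when full_cuda_graph is True.")
        max_cudagraph_size = max(capture_sizes)
        # Precondition 3: FA3's scheduling heuristic caps the capture size.
        if max_cudagraph_size > 992:
            raise ValueError("Capture size larger than 992 is not supported "
                             "for full cuda graph.")
        # One persistent buffer, allocated once, reused by every graph replay.
        scheduler_metadata = torch.zeros(max_num_reqs + 1, dtype=torch.int32,
                                         device=device)
        return max_cudagraph_size, scheduler_metadata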
@@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
+                    num_splits=self.max_num_splits,
                 )
             return None
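For context, `num_splits` is threaded into the AoT scheduling call that this builder makes when FA3 is available. A hedged sketch of that call site; only `cu_seqlens_q`, `causal`, `window_size`, and `num_splits` appear in this diff, so the remaining keyword names are assumptions about vllm-flash-attn's `get_scheduler_metadata`:

    def schedule(self, batch_size, cu_query_lens, max_query_len,
                 seqlens, max_seq_len, causal):
        # FA2 has no AoT scheduler; callers receive None and fall back
        # to the kernel's dynamic split heuristic.
        if self.aot_schedule:
            return get_scheduler_metadata(
                batch_size=batch_size,           # assumed kwarg
                max_seqlen_q=max_query_len,      # assumed kwarg
                max_seqlen_k=max_seq_len,        # assumed kwarg
                cache_seqlens=seqlens,           # assumed kwarg
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                window_size=self.aot_sliding_window,
                num_splits=self.max_num_splits,  # 0 = no upper bound
            )
        return None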
@@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+        max_num_splits = 0
+        if (self.use_full_cuda_graph
+                and num_actual_tokens <= self.max_cudagraph_size):
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
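The copy-then-zero step above is the standard pattern for replaying a captured graph whose kernels always read one fixed, worst-case buffer while the live payload shrinks and grows per step. A self-contained illustration of the pattern (buffer names are illustrative, not vLLM's):

    import torch

    MAX_ENTRIES = 129  # e.g. max_num_reqs + 1, fixed at capture time
    persistent = torch.zeros(MAX_ENTRIES, dtype=torch.int32, device="cuda")

    def refresh(step_metadata: torch.Tensor) -> torch.Tensor:
        # Load this step's scheduler metadata into the captured buffer.
        n = step_metadata.shape[0]
        persistent[:n] = step_metadata
        # Zero the tail: without this, thread blocks reading entries left
        # over from an earlier, larger batch would act on stale scheduling
        # data and could overwrite the output buffer.
        persistent[n:] = 0
        return persistent[:n]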
@@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
         return attn_metadata
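The memory caveat quoted above is easy to put numbers on: the split-KV path materializes partial outputs of shape [num_splits, num_heads, num_tokens, head_size] before the final reduction. Back-of-the-envelope arithmetic, where only the 16 and the 992 come from this diff and the head count, head size, and fp32 accumulation are illustrative assumptions:

    num_splits = 16        # _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
    num_heads = 32         # assumed model shape
    num_tokens = 992       # largest supported capture size
    head_size = 128        # assumed
    elem_bytes = 4         # assumed fp32 partial accumulators

    buf = num_splits * num_heads * num_tokens * head_size * elem_bytes
    print(f"{buf / 2**20:.0f} MiB")  # ~248 MiB

which is why the upper bound is applied only when full cuda graphs are in use, and only for batches that fit under the capture size.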
@@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                num_splits=attn_metadata.max_num_splits,
             )
         return output
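End to end, the feature is opt-in through the compilation config. A hedged usage sketch; `full_cuda_graph` and `cudagraph_capture_sizes` match the fields referenced in this diff, while the `LLM(...)` wiring reflects vLLM's public API at the time and should be treated as an assumption:

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Requires FA3 (for AoT scheduling) and all capture sizes <= 992,
    # per the checks added in this commit.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
        compilation_config=CompilationConfig(
            full_cuda_graph=True,
            cudagraph_capture_sizes=[1, 2, 4, 8, 16, 32, 64, 128],
        ),
    )
    print(llm.generate("Hello")[0].outputs[0].text)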