[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (#20301)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-07-01 09:07:36 -07:00 committed by GitHub
parent 314af8617c
commit 8acb4badee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 7 deletions

View File

@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524 GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@ -36,6 +36,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@ -114,6 +117,7 @@ class FlashAttentionMetadata:
# Optional aot scheduling # Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0
# for local attention # for local attention
@dataclass @dataclass
@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_table = block_table self.block_table = block_table
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3) self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph self.use_full_cuda_graph = compilation_config.full_cuda_graph
if self.use_full_cuda_graph: if self.use_full_cuda_graph:
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode if not self.aot_schedule:
# yet. This is because the scheduler and kernel need to always use raise ValueError(
# the same num_splits (which acts as an upper bound with the "AoT scheduling is required for full cuda graph.")
# dynamic split scheduler) which is currently heuristically decided capture_sizes = compilation_config.cudagraph_capture_sizes
# by the kernel launching code. if not capture_sizes:
self.aot_schedule = False raise ValueError(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True.")
self.max_cudagraph_size = max(capture_sizes)
if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
raise ValueError(
"Capture size larger than 992 is not supported for "
"full cuda graph.")
self.scheduler_metadata = torch.zeros(
self.runner.max_num_reqs + 1,
dtype=torch.int32,
device=self.runner.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# Sliding window size to be used with the AOT scheduler will be # Sliding window size to be used with the AOT scheduler will be
# populated on first build() call. # populated on first build() call.
@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
causal=causal, causal=causal,
window_size=self.aot_sliding_window, window_size=self.aot_sliding_window,
num_splits=self.max_num_splits,
) )
return None return None
@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
causal=True) causal=True)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0
if (self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
suffix_kv_lens=suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata, local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
) )
return attn_metadata return attn_metadata
@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=layer._q_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
) )
return output return output