From 8acb4badee6421fa0dd6892396e58078d25bbb6e Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 1 Jul 2025 09:07:36 -0700
Subject: [PATCH] [CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling
 (#20301)

Signed-off-by: Woosuk Kwon
---
 cmake/external_projects/vllm_flash_attn.cmake |  2 +-
 vllm/v1/attention/backends/flash_attn.py      | 59 +++++++++++++++++--
 2 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 7b17018f65ab4..ef45a5fbebf69 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
+        GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 527b31153410b..6182b2f9b2bd4 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
     # for local attention
     @dataclass
@@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
         if self.use_full_cuda_graph:
-            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
-            # yet. This is because the scheduler and kernel need to always use
-            # the same num_splits (which acts as an upper bound with the
-            # dynamic split scheduler) which is currently heuristically decided
-            # by the kernel launching code.
-            self.aot_schedule = False
+            if not self.aot_schedule:
+                raise ValueError(
+                    "AoT scheduling is required for full cuda graph.")
+            capture_sizes = compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                raise ValueError(
+                    "cudagraph_capture_sizes should not be None when "
+                    "full_cuda_graph is True.")
+            self.max_cudagraph_size = max(capture_sizes)
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                self.runner.max_num_reqs + 1,
+                dtype=torch.int32,
+                device=self.runner.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
+                    num_splits=self.max_num_splits,
                 )
             return None
 
@@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+        max_num_splits = 0
+        if (self.use_full_cuda_graph
+                and num_actual_tokens <= self.max_cudagraph_size):
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
@@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
         return attn_metadata
 
@@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                num_splits=attn_metadata.max_num_splits,
             )
             return output
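
For reference, below is a minimal standalone sketch (hypothetical names, not vLLM's actual class) of the persistent-buffer pattern the builder uses for scheduler_metadata in the hunks above: CUDA graph replay reuses the tensor addresses captured at graph-capture time, so each build() copies the freshly computed schedule into a buffer allocated once at init and zeroes the tail, instead of handing the kernel a newly allocated tensor.

import torch

class PersistentSchedulerMetadata:
    """Illustrative only; mirrors the pattern in the patch, not its code."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        # Allocated once, so its address stays stable across graph replays.
        self.buffer = torch.zeros(max_num_reqs + 1,
                                  dtype=torch.int32,
                                  device=device)

    def update(self, fresh_metadata: torch.Tensor) -> torch.Tensor:
        n = fresh_metadata.shape[0]
        # Copy this step's schedule into the persistent buffer.
        self.buffer[:n] = fresh_metadata
        # Zero the tail so no thread block consumes stale entries left over
        # from an earlier, larger batch and overwrites the output buffer.
        self.buffer[n:] = 0
        return self.buffer[:n]

The zeroing step corresponds to the NOTE(woosuk) comment in the build() hunk: stale scheduler entries are not merely wasteful, they can direct thread blocks at invalid work and corrupt the output.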
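
A rough back-of-the-envelope helper (an assumption for illustration, not taken from the patch) shows why num_splits is bounded at 16 only under full CUDA graphs and left unbounded (0) otherwise: the split-KV path allocates partial-result buffers of roughly [num_splits, num_heads, num_tokens, head_size], so a fixed upper bound trades extra scratch memory for an allocation the captured graph can reuse. The element size and example shapes below are made up.

def split_kv_buffer_bytes(num_splits: int,
                          num_heads: int,
                          num_tokens: int,
                          head_size: int,
                          bytes_per_elem: int = 4) -> int:
    # Assumes partial outputs are accumulated in fp32 (4 bytes per element)
    # before the final combine step.
    return num_splits * num_heads * num_tokens * head_size * bytes_per_elem

if __name__ == "__main__":
    # e.g. 16 splits, 32 query heads, 992 tokens (the max capture size allowed
    # by this patch), head size 128 -> about 248 MiB of scratch space.
    mib = split_kv_buffer_bytes(16, 32, 992, 128) / 2**20
    print(f"{mib:.0f} MiB")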