[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (#20301)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 8acb4badee (parent 314af8617c)
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
           vllm-flash-attn
           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
+          GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttentionBackend(AttentionBackend):
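For intuition on why this default bound is kept small: FA3's split-KV path materializes partial outputs of shape [num_splits, num_heads, num_tokens, head_size] (per the NOTE(woosuk) comment in the build() hunk further down), so this capture-time bound directly scales that buffer. A minimal sketch, assuming fp32 partial accumulation and illustrative model dimensions (none of these numbers are taken from vLLM):

def splitkv_buffer_bytes(num_splits: int, num_heads: int,
                         num_tokens: int, head_size: int,
                         bytes_per_elem: int = 4) -> int:
    # Shape from the NOTE below: [num_splits, num_heads, num_tokens,
    # head_size]; fp32 accumulation is an assumption of this sketch.
    return num_splits * num_heads * num_tokens * head_size * bytes_per_elem

# 16 splits, 32 heads, 992 tokens (the max supported capture size),
# head_size 128 -> 248 MiB of partial-output buffer.
print(splitkv_buffer_bytes(16, 32, 992, 128) / 2**20)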
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
     # for local attention
     @dataclass
@@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
         if self.use_full_cuda_graph:
-            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
-            # yet. This is because the scheduler and kernel need to always use
-            # the same num_splits (which acts as an upper bound with the
-            # dynamic split scheduler) which is currently heuristically decided
-            # by the kernel launching code.
-            self.aot_schedule = False
+            if not self.aot_schedule:
+                raise ValueError(
+                    "AoT scheduling is required for full cuda graph.")
+            capture_sizes = compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                raise ValueError(
+                    "cudagraph_capture_sizes should not be None when "
+                    "full_cuda_graph is True.")
+            self.max_cudagraph_size = max(capture_sizes)
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                self.runner.max_num_reqs + 1,
+                dtype=torch.int32,
+                device=self.runner.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
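The pre-allocated self.scheduler_metadata tensor exists because a captured CUDA graph replays fixed device pointers: any tensor a captured kernel reads must keep its address across replays and be updated in place. A minimal standalone sketch of that pattern (illustrative only, not vLLM code):

import torch

# Persistent buffer allocated once; its address is baked into the graph.
buf = torch.zeros(8, dtype=torch.int32, device="cuda")

# Warm up on a side stream before capture, per PyTorch's CUDA graph docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = buf * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = buf * 2  # the captured kernel reads buf at its fixed address

# Before each replay, write new values in place and zero the unused tail,
# mirroring what build() does with scheduler_metadata below.
buf[:4] = torch.arange(4, dtype=torch.int32, device="cuda")
buf[4:] = 0
g.replay()
print(out[:4])  # reflects the values written after capture: [0, 2, 4, 6]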
@@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
+                    num_splits=self.max_num_splits,
                 )
             return None
@@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+        max_num_splits = 0
+        if (self.use_full_cuda_graph
+                and num_actual_tokens <= self.max_cudagraph_size):
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
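The two conditions above combine into a simple policy: pin num_splits to the pre-sized bound only when the batch will actually be replayed through a captured graph, and pass 0 otherwise so FA3's launcher picks the split count heuristically. A hedged restatement as a standalone function (the parameter names are stand-ins for the builder's attributes):

def effective_num_splits(use_full_cuda_graph: bool, num_actual_tokens: int,
                         max_cudagraph_size: int, max_num_splits: int) -> int:
    # 0 = let the kernel launcher choose num_splits dynamically;
    # >0 = hard upper bound so graph-captured buffers are large enough.
    if use_full_cuda_graph and num_actual_tokens <= max_cudagraph_size:
        return max_num_splits
    return 0

assert effective_num_splits(True, 512, 992, 16) == 16   # replayed via graph
assert effective_num_splits(True, 2048, 992, 16) == 0   # too big: eager path
assert effective_num_splits(False, 512, 992, 16) == 0   # no full cuda graph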
@@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
         return attn_metadata
@@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
             q_descale=layer._q_scale.expand(descale_shape),
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            num_splits=attn_metadata.max_num_splits,
         )
         return output
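Together with the num_splits=self.max_num_splits argument added to the AoT scheduler call earlier, this launch-side argument closes the loop described in the removed NOTE(lucas) comment: the schedule is planned and the kernel is launched with one shared num_splits bound, instead of the launcher deciding heuristically on its own. A sketch of the invariant, with hypothetical names (vLLM enforces it by construction, since both sites read the same max_num_splits value):

def check_num_splits_consistent(planned: int, launched: int) -> None:
    # Per the removed NOTE(lucas): the scheduler and the kernel must
    # always use the same num_splits upper bound with the dynamic
    # split scheduler.
    assert planned == launched, "AoT schedule/kernel num_splits mismatch"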