mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 05:37:53 +08:00
[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (#20301)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
314af8617c
commit
8acb4badee
@ -38,7 +38,7 @@ else()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
vllm-flash-attn
|
vllm-flash-attn
|
||||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||||
GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
|
GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
# Don't share the vllm-flash-attn build between build types
|
# Don't share the vllm-flash-attn build between build types
|
||||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||||
|
|||||||
@ -36,6 +36,9 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
|
||||||
|
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
@ -114,6 +117,7 @@ class FlashAttentionMetadata:
|
|||||||
# Optional aot scheduling
|
# Optional aot scheduling
|
||||||
scheduler_metadata: Optional[torch.Tensor] = None
|
scheduler_metadata: Optional[torch.Tensor] = None
|
||||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||||
|
max_num_splits: int = 0
|
||||||
|
|
||||||
# for local attention
|
# for local attention
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
|
|||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
self.block_table = block_table
|
self.block_table = block_table
|
||||||
|
|
||||||
|
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||||
self.use_full_cuda_graph = compilation_config.full_cuda_graph
|
self.use_full_cuda_graph = compilation_config.full_cuda_graph
|
||||||
if self.use_full_cuda_graph:
|
if self.use_full_cuda_graph:
|
||||||
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode
|
if not self.aot_schedule:
|
||||||
# yet. This is because the scheduler and kernel need to always use
|
raise ValueError(
|
||||||
# the same num_splits (which acts as an upper bound with the
|
"AoT scheduling is required for full cuda graph.")
|
||||||
# dynamic split scheduler) which is currently heuristically decided
|
capture_sizes = compilation_config.cudagraph_capture_sizes
|
||||||
# by the kernel launching code.
|
if not capture_sizes:
|
||||||
self.aot_schedule = False
|
raise ValueError(
|
||||||
|
"cudagraph_capture_sizes should not be None when "
|
||||||
|
"full_cuda_graph is True.")
|
||||||
|
self.max_cudagraph_size = max(capture_sizes)
|
||||||
|
if self.max_cudagraph_size > 992:
|
||||||
|
# This condition derives from FA3's internal heuristic.
|
||||||
|
# TODO(woosuk): Support larger cudagraph sizes.
|
||||||
|
raise ValueError(
|
||||||
|
"Capture size larger than 992 is not supported for "
|
||||||
|
"full cuda graph.")
|
||||||
|
|
||||||
|
self.scheduler_metadata = torch.zeros(
|
||||||
|
self.runner.max_num_reqs + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.runner.device,
|
||||||
|
)
|
||||||
|
# When using cuda graph, we need to set the upper bound of the
|
||||||
|
# number of splits so that large enough intermediate buffers are
|
||||||
|
# pre-allocated during capture.
|
||||||
|
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
|
||||||
|
|
||||||
# Sliding window size to be used with the AOT scheduler will be
|
# Sliding window size to be used with the AOT scheduler will be
|
||||||
# populated on first build() call.
|
# populated on first build() call.
|
||||||
@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
|
|||||||
cu_seqlens_q=cu_query_lens,
|
cu_seqlens_q=cu_query_lens,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
window_size=self.aot_sliding_window,
|
window_size=self.aot_sliding_window,
|
||||||
|
num_splits=self.max_num_splits,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
|
|||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
causal=True)
|
causal=True)
|
||||||
|
|
||||||
|
if self.use_full_cuda_graph:
|
||||||
|
assert scheduler_metadata is not None
|
||||||
|
n = scheduler_metadata.shape[0]
|
||||||
|
self.scheduler_metadata[:n] = scheduler_metadata
|
||||||
|
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||||
|
# metadata to guarantee the correctness. Otherwise, some thread
|
||||||
|
# blocks may use the invalid scheduler metadata and overwrite the
|
||||||
|
# output buffer.
|
||||||
|
self.scheduler_metadata[n:] = 0
|
||||||
|
scheduler_metadata = self.scheduler_metadata[:n]
|
||||||
|
|
||||||
|
max_num_splits = 0
|
||||||
|
if (self.use_full_cuda_graph
|
||||||
|
and num_actual_tokens <= self.max_cudagraph_size):
|
||||||
|
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||||
|
# usage, because the intermediate buffers of size [num_splits,
|
||||||
|
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||||
|
# we only set num_splits when using cuda graphs.
|
||||||
|
max_num_splits = self.max_num_splits
|
||||||
|
|
||||||
attn_metadata = FlashAttentionMetadata(
|
attn_metadata = FlashAttentionMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
|
|||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
local_attn_metadata=local_attn_metadata,
|
local_attn_metadata=local_attn_metadata,
|
||||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||||
|
max_num_splits=max_num_splits,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
q_descale=layer._q_scale.expand(descale_shape),
|
q_descale=layer._q_scale.expand(descale_shape),
|
||||||
k_descale=layer._k_scale.expand(descale_shape),
|
k_descale=layer._k_scale.expand(descale_shape),
|
||||||
v_descale=layer._v_scale.expand(descale_shape),
|
v_descale=layer._v_scale.expand(descale_shape),
|
||||||
|
num_splits=attn_metadata.max_num_splits,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user