[Bugfix] Support full cuda graph with sliding window attention

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-08-03 20:28:31 -07:00
parent c2e75b3c11
commit 3e56ae2878

@@ -205,9 +205,11 @@ class FlashAttentionMetadataBuilder(
             # pre-allocated during capture.
             self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

-        # Sliding window size to be used with the AOT scheduler will be
-        # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
+        sliding_window = getattr(kv_cache_spec, "sliding_window", None)
+        if sliding_window is not None:
+            self.aot_sliding_window = (sliding_window - 1, 0)
+        else:
+            self.aot_sliding_window = (-1, -1)

     def build(self,
               common_prefix_len: int,
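
The added lines map the KV-cache spec's sliding window onto the (left, right) window tuple that FlashAttention's ahead-of-time (AOT) scheduler expects: a causal window of W tokens lets each query attend to itself plus the W - 1 preceding tokens, and (-1, -1) disables windowing. A minimal, self-contained sketch of that mapping (the helper name aot_window_from_spec is hypothetical, not part of this commit):

    from typing import Optional

    def aot_window_from_spec(sliding_window: Optional[int]) -> tuple[int, int]:
        # Hypothetical helper mirroring the added __init__ logic above.
        if sliding_window is not None:
            # A causal sliding window of W tokens -> (left, right) = (W - 1, 0):
            # each query attends to itself plus the W - 1 previous tokens.
            return (sliding_window - 1, 0)
        # (-1, -1) means no window, i.e. full causal attention.
        return (-1, -1)

    assert aot_window_from_spec(4096) == (4095, 0)
    assert aot_window_from_spec(None) == (-1, -1)
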
@@ -231,23 +233,6 @@ class FlashAttentionMetadataBuilder(
         # the overhead of the aot schedule is not worth it for spec-decode
         aot_schedule = self.aot_schedule and not fast_build

-        if self.aot_sliding_window is None:
-            self.aot_sliding_window = (-1, -1)
-            # For the AOT scheduler we need the sliding window value to be
-            # constant for all layers. We have to populate this on the first
-            # build() call, once the layers are constructed (it cannot be
-            # populated in __init__).
-            if aot_schedule:
-                sliding_window_configs = _get_sliding_window_configs(
-                    self.vllm_config)
-                if len(sliding_window_configs) == 1:
-                    sliding_window_config = sliding_window_configs.pop()
-                    if sliding_window_config is not None:
-                        self.aot_sliding_window = sliding_window_config
-                elif len(sliding_window_configs) > 1:
-                    self.aot_schedule = False
-                    aot_schedule = False
-
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             cache_dtype = self.cache_config.cache_dtype
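
For reference, the removed branch resolved the window lazily on the first build() call from the set of per-layer configurations returned by _get_sliding_window_configs, and disabled AOT scheduling when layers disagreed. Because the builder now receives kv_cache_spec up front, the window is known in __init__ and this fallback becomes dead code. A minimal sketch of the removed decision logic (assuming, per the diff, that the configs form a set of optional (left, right) tuples):

    from typing import Optional

    WindowConfig = Optional[tuple[int, int]]

    def resolve_aot_window(
            configs: set[WindowConfig]) -> tuple[tuple[int, int], bool]:
        # Returns (aot_sliding_window, aot_schedule_enabled), mirroring the
        # removed first-build() detection above.
        if len(configs) == 1:
            only = next(iter(configs))
            if only is not None:
                return only, True      # one uniform sliding window
            return (-1, -1), True      # no layer uses a sliding window
        if len(configs) > 1:
            # Mixed window sizes: a single AOT schedule cannot serve all
            # layers, so AOT scheduling was disabled entirely.
            return (-1, -1), False
        return (-1, -1), True          # empty set: treat as "no window"

Fixing the window at construction time also keeps the AOT schedule's shape stable between CUDA graph capture and replay, which is plausibly what "full cuda graph with sliding window attention" requires here.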