From 3e56ae2878fce62b0336c1571388c395ff6296e0 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 3 Aug 2025 20:28:31 -0700
Subject: [PATCH] [Bugfix] Support full cuda graph with sliding window attention

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flash_attn.py | 25 +++++--------------------
 1 file changed, 5 insertions(+), 20 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f086bab2556eb..e1575ec827561 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -205,9 +205,11 @@ class FlashAttentionMetadataBuilder(
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        # Sliding window size to be used with the AOT scheduler will be
-        # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
+        sliding_window = getattr(kv_cache_spec, "sliding_window", None)
+        if sliding_window is not None:
+            self.aot_sliding_window = (sliding_window - 1, 0)
+        else:
+            self.aot_sliding_window = (-1, -1)
 
     def build(self,
               common_prefix_len: int,
@@ -231,23 +233,6 @@ class FlashAttentionMetadataBuilder(
         # the overhead of the aot schedule is not worth it for spec-decode
         aot_schedule = self.aot_schedule and not fast_build
 
-        if self.aot_sliding_window is None:
-            self.aot_sliding_window = (-1, -1)
-            # For the AOT scheduler we need the sliding window value to be
-            # constant for all layers to. We have to populate this on the first
-            # build() call so the layers are constructed (cannot populate)
-            # in __init__.
-            if aot_schedule:
-                sliding_window_configs = _get_sliding_window_configs(
-                    self.vllm_config)
-                if len(sliding_window_configs) == 1:
-                    sliding_window_config = sliding_window_configs.pop()
-                    if sliding_window_config is not None:
-                        self.aot_sliding_window = sliding_window_config
-                elif len(sliding_window_configs) > 1:
-                    self.aot_schedule = False
-                    aot_schedule = False
-
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             cache_dtype = self.cache_config.cache_dtype
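
For reference, a minimal standalone sketch of the (left, right) window-tuple
convention the new __init__ code relies on. The helper name
aot_sliding_window_tuple is hypothetical and not part of the patch; the logic
mirrors the lines added above. FlashAttention treats (-1, -1) as "no sliding
window", and a causal window of size W becomes (W - 1, 0) because a query
token attends to itself plus the W - 1 tokens preceding it.

    from typing import Optional

    def aot_sliding_window_tuple(
            sliding_window: Optional[int]) -> tuple[int, int]:
        """Hypothetical helper mirroring the __init__ logic in the patch."""
        if sliding_window is None:
            # No window configured: (-1, -1) disables the sliding window,
            # so a query may attend to all prior tokens (causal).
            return (-1, -1)
        # A window of size W covers the query token plus the W - 1 tokens
        # before it, hence a left window of W - 1 and a right window of 0.
        return (sliding_window - 1, 0)

    assert aot_sliding_window_tuple(None) == (-1, -1)
    assert aot_sliding_window_tuple(4096) == (4095, 0)

Computing this tuple eagerly from the builder's kv_cache_spec, rather than
scanning all layers' sliding-window configs on the first build() call, is
what lets the value be fixed before CUDA graph capture.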