[BugFix][Attention] Fix sliding window attention in V1 giving incorrect results (#17574)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson 2025-05-02 14:01:38 -04:00 committed by GitHub
parent 4c33d67321
commit 0f87d8f7b2

@@ -10,9 +10,11 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -273,13 +275,23 @@ def make_local_attention_virtual_batches(
        block_table_local


def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Get the set of all sliding window configs used in the model."""
    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
    layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        sliding_window_configs.add(layer.impl.sliding_window)
    return sliding_window_configs
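
As a rough illustration (an editor's sketch, not code from this commit; _StubImpl is a stand-in for FlashAttentionImpl), the helper simply deduplicates the per-layer window settings so the builder can tell a single shared config apart from a mix of configs:

from typing import Optional

class _StubImpl:
    """Stand-in for FlashAttentionImpl; only carries the window tuple."""

    def __init__(self, sliding_window: Optional[tuple[int, int]]):
        self.sliding_window = sliding_window

def collect_window_configs(
        impls: dict[str, _StubImpl]) -> set[Optional[tuple[int, int]]]:
    # Same idea as _get_sliding_window_configs: a set collapses identical
    # per-layer settings into one entry.
    return {impl.sliding_window for impl in impls.values()}

uniform = {"layers.0": _StubImpl((1023, 0)), "layers.1": _StubImpl((1023, 0))}
mixed = {"layers.0": _StubImpl(None), "layers.1": _StubImpl((1023, 0))}
assert len(collect_window_configs(uniform)) == 1  # one shared config
assert len(collect_window_configs(mixed)) == 2    # mixed configs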
class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        model_config = runner.model_config

        self.runner = runner
        self.aot_schedule = (get_flash_attn_version() == 3)
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
@@ -287,6 +299,11 @@ class FlashAttentionMetadataBuilder:
        self.headdim = model_config.get_head_size()
        self.page_size = self.runner.block_size
        self.aot_schedule = (get_flash_attn_version() == 3)

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False
@@ -304,6 +321,22 @@ class FlashAttentionMetadataBuilder:
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate it on the first
            # build() call, once the layers have been constructed (so it
            # cannot be populated in __init__).
            if self.aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.runner.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            if self.aot_schedule:
@@ -318,6 +351,7 @@ class FlashAttentionMetadataBuilder:
                    page_size=self.page_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                )
            return None
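
Why threading window_size through matters: the ahead-of-time schedule must agree with the masking the kernel later applies, otherwise windowed layers are planned as if every key were visible. A tiny pure-Python sketch of the (left, right) window convention assumed above, where (-1, -1) means unbounded:

def visible_keys(query_pos: int,
                 window_size: tuple[int, int] = (-1, -1)) -> list[int]:
    """Key positions a causal query may attend to under a (left, right) window."""
    left, _right = window_size           # right half is unused under causal masking
    lo = 0 if left < 0 else max(0, query_pos - left)
    return list(range(lo, query_pos + 1))

# No window: position 5 attends to every earlier position and itself.
assert visible_keys(5) == [0, 1, 2, 3, 4, 5]
# Sliding window of 3 to the left: only the most recent keys stay visible,
# which is what the scheduler must plan for on windowed layers.
assert visible_keys(5, (3, 0)) == [2, 3, 4, 5]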