[BugFix][Attention] Fix sliding window attention in V1 giving incorrect results (#17574)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent 4c33d67321
commit 0f87d8f7b2
@@ -10,9 +10,11 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -273,13 +275,23 @@ def make_local_attention_virtual_batches(
         block_table_local


+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:

     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config

         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
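For illustration only (not part of the commit): the _get_sliding_window_configs helper added above reduces the per-layer sliding-window settings to the set of distinct configurations, which the builder later uses to decide whether a single window can be handed to the ahead-of-time (AOT) scheduler. A minimal standalone sketch of the same idea, with a hypothetical FakeLayerImpl standing in for vLLM's FlashAttentionImpl:

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeLayerImpl:
    # Hypothetical stand-in for FlashAttentionImpl; only the attribute the
    # helper reads is modeled here.
    sliding_window: Optional[tuple[int, int]]


def unique_sliding_window_configs(
        impls: list[FakeLayerImpl]) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding window configs across all layers."""
    return {impl.sliding_window for impl in impls}


# A model whose layers alternate between sliding-window and full attention
# (e.g. a Gemma-2-style pattern) yields two distinct configs, so no single
# window can be shared across the whole model.
impls = [FakeLayerImpl((4095, 0)), FakeLayerImpl(None)] * 2
print(unique_sliding_window_configs(impls))  # e.g. {(4095, 0), None}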
@@ -287,6 +299,11 @@ class FlashAttentionMetadataBuilder:
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size

+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
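FlashAttention describes a sliding window as a (left, right) tuple, with -1 meaning unbounded on that side, so the (-1, -1) default above means "no sliding window". The sketch below is illustrative only (the allowed_keys helper is not part of vLLM): it shows which key positions a causal query may attend to under a given window, which is the information the AOT scheduler needs; per the commit title, scheduling without it gave incorrect results for sliding-window models.

def allowed_keys(query_pos: int, num_keys: int,
                 window: tuple[int, int] = (-1, -1)) -> list[int]:
    """Key positions a causal query at query_pos may attend to.

    window follows the FlashAttention (left, right) convention, where -1
    means unbounded on that side; (-1, -1) disables the sliding window.
    """
    left, right = window
    lo = 0 if left < 0 else max(0, query_pos - left)
    hi = query_pos if right < 0 else min(query_pos + right, num_keys - 1)
    hi = min(hi, query_pos)  # causal: never look past the query position
    return list(range(lo, hi + 1))


# With no window the query sees every earlier key; with a window of (2, 0)
# it sees only the two most recent keys plus itself.
print(allowed_keys(5, 8))          # [0, 1, 2, 3, 4, 5]
print(allowed_keys(5, 8, (2, 0)))  # [3, 4, 5]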
@@ -304,6 +321,22 @@ class FlashAttentionMetadataBuilder:
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()

+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate it on the first
+            # build() call, once the layers are constructed (it cannot be
+            # populated in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
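Restated as a standalone sketch (illustrative only; resolve_aot_window is not a vLLM function): the first-build branch above keeps AOT scheduling with a single shared window, keeps the (-1, -1) default when no layer uses a window, and disables AOT scheduling when layers disagree rather than schedule some of them with the wrong window.

from typing import Optional


def resolve_aot_window(
    configs: set[Optional[tuple[int, int]]],
) -> tuple[bool, tuple[int, int]]:
    """Return (aot_schedule, window) from the set of per-layer configs."""
    if len(configs) == 1:
        only = next(iter(configs))
        if only is not None:
            return True, only      # every layer shares one window
        return True, (-1, -1)      # no layer uses a sliding window
    if len(configs) > 1:
        return False, (-1, -1)     # mixed windows: fall back, no AOT
    return True, (-1, -1)          # no attention layers found


assert resolve_aot_window({(4095, 0)}) == (True, (4095, 0))
assert resolve_aot_window({None}) == (True, (-1, -1))
assert resolve_aot_window({(4095, 0), None}) == (False, (-1, -1))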
@@ -318,6 +351,7 @@ class FlashAttentionMetadataBuilder:
                     page_size=self.page_size,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
+                    window_size=self.aot_sliding_window,
                 )
             return None
