mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:26:00 +08:00
[V1] Add sliding window support to Flex Attention backend (#24089)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
7ed82d1974
commit
cf56cf78b4
@ -1,15 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 attention backends without GPUModelRunner dependency."""
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
set_kv_cache_layout)
|
||||
@ -183,13 +188,19 @@ class MockAttentionLayer:
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
|
||||
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str], vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor) -> torch.Tensor:
|
||||
def run_attention_backend(
|
||||
backend: _Backend,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
|
||||
# Handle special case for FLEX_ATTENTION_SLOW
|
||||
@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sliding_window=sliding_window,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
|
||||
@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
def _test_backend_correctness(
|
||||
batch_spec: BatchSpec,
|
||||
model: str,
|
||||
backend_to_test: list[Union[_Backend, str]],
|
||||
mask_mod,
|
||||
*,
|
||||
block_size: int = 16,
|
||||
atol: float = 1e-2,
|
||||
rtol: float = 1e-2,
|
||||
):
|
||||
"""
|
||||
Test that all backends produce similar outputs to a reference implementation
|
||||
using torch.nn.functional.scaled_dot_product_attention.
|
||||
@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
simulated paged KV cache.
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
current_platform.seed_everything(42)
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=8192)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
sliding_window = vllm_config.model_config.get_sliding_window()
|
||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
# Create causal mask: query token i attends to positions 0 to
|
||||
# (context_len + i)
|
||||
kv_len = s_len
|
||||
offset = context_len
|
||||
attn_mask = torch.full((q_len, kv_len),
|
||||
float('-inf'),
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
for i in range(q_len):
|
||||
attn_mask[i, :offset + i + 1] = 0.0
|
||||
|
||||
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in,
|
||||
k_sdpa_in,
|
||||
v_sdpa_in,
|
||||
attn_mask=attn_mask,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
# Convert back to (L, H, D)
|
||||
final_mask_mod = partial(mask_mod, context_len=context_len)
|
||||
block_mask = create_block_mask(final_mask_mod,
|
||||
B=None,
|
||||
H=None,
|
||||
Q_LEN=q_len,
|
||||
KV_LEN=kv_len,
|
||||
device=device)
|
||||
sdpa_out_i = flex_attention(q_sdpa_in,
|
||||
k_sdpa_in,
|
||||
v_sdpa_in,
|
||||
block_mask=block_mask,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
|
||||
all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
|
||||
|
||||
# Inputs for vLLM backends are just the new tokens
|
||||
@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
# 4. Run vLLM backends and compare
|
||||
# Note: flex_attention has known Triton kernel compatibility issues
|
||||
# with test infrastructures
|
||||
for backend_name in BACKENDS_TO_TEST:
|
||||
for backend_name in backend_to_test:
|
||||
# FlashAttentionm + FlexAttention:
|
||||
# [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
# FlashInfer:
|
||||
@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
2, 3).contiguous().transpose(2, 3)
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
backend_output = run_attention_backend(backend_name, kv_cache_spec,
|
||||
["placeholder"], vllm_config,
|
||||
device, common_attn_metadata,
|
||||
query_vllm, key_vllm,
|
||||
value_vllm,
|
||||
kv_cache_for_backend)
|
||||
backend_output = run_attention_backend(
|
||||
backend_name,
|
||||
kv_cache_spec,
|
||||
["placeholder"],
|
||||
vllm_config,
|
||||
device,
|
||||
common_attn_metadata,
|
||||
query_vllm,
|
||||
key_vllm,
|
||||
value_vllm,
|
||||
kv_cache_for_backend,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == sdpa_output.shape, (
|
||||
@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
f"[{backend_name}] produced non-finite values")
|
||||
|
||||
# Check numerical similarity
|
||||
rtol = 1e-2
|
||||
atol = 5e-3
|
||||
def error_msg(msg: str, backend_name: str):
|
||||
return (f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"{msg}")
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - sdpa_output) /
|
||||
torch.abs(sdpa_output)).item()
|
||||
all_close = torch.allclose(backend_output,
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_output,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
atol=atol,
|
||||
msg=partial(error_msg,
|
||||
backend_name=backend_name))
|
||||
|
||||
assert all_close, (
|
||||
f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_causal_backend_correctness(batch_spec_name: str, model: str):
|
||||
"""Test backend's correctness with causal attention."""
|
||||
|
||||
def causal_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor,
|
||||
*,
|
||||
context_len: int,
|
||||
):
|
||||
return (q_idx + context_len) >= kv_idx
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0") else [])
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
|
||||
causal_mask_mod)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
if LARGE_BLOCK_BACKENDS:
|
||||
_test_backend_correctness(batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
causal_mask_mod,
|
||||
block_size=128)
|
||||
|
||||
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_medium", "large_decode",
|
||||
"large_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
|
||||
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
"""Test backend's correctness with sliding window attention."""
|
||||
|
||||
def sliding_window_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor,
|
||||
*,
|
||||
context_len: int,
|
||||
sliding_window: int,
|
||||
):
|
||||
causal_mask = q_idx + context_len >= kv_idx
|
||||
window_mask = q_idx + context_len - kv_idx < sliding_window
|
||||
return causal_mask & window_mask
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
model_config = ModelConfig(model=model,
|
||||
max_model_len=max(batch_spec.seq_lens))
|
||||
sliding_window = model_config.get_sliding_window()
|
||||
sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0") else [])
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
|
||||
if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
if LARGE_BLOCK_BACKENDS:
|
||||
_test_backend_correctness(batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn,
|
||||
block_size=128)
|
||||
|
||||
@ -9,7 +9,7 @@ import torch
|
||||
import torch._dynamo.decorators
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
_score_mod_signature,
|
||||
_score_mod_signature, and_masks,
|
||||
create_block_mask,
|
||||
flex_attention)
|
||||
|
||||
@ -292,6 +292,7 @@ class FlexAttentionMetadata:
|
||||
q_block_size: int = 16
|
||||
kv_block_size: int = 16
|
||||
transformed_score_mod: Optional[_score_mod_signature] = None
|
||||
sliding_window: Optional[int] = None
|
||||
|
||||
def _convert_physical_to_logical(
|
||||
self,
|
||||
@ -380,6 +381,53 @@ class FlexAttentionMetadata:
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the sliding window mask_mod function for FlexAttention.
|
||||
|
||||
Note that the sliding window mask here is bidirectional, we need
|
||||
to mask it with the bidirectional/causal mask for encoder/decoder.
|
||||
"""
|
||||
|
||||
if self.sliding_window is None:
|
||||
raise ValueError(
|
||||
"sliding_window must be set for sliding window attention")
|
||||
|
||||
def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
|
||||
q_idx: torch.Tensor, kv_idx: torch.Tensor):
|
||||
return torch.abs(q_idx - kv_idx) < self.sliding_window
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(is_valid, logical_q_idx,
|
||||
logical_kv_idx) = self._convert_physical_to_logical(
|
||||
self.doc_ids, q_idx, physical_kv_idx)
|
||||
return torch.where(
|
||||
is_valid,
|
||||
sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
|
||||
False,
|
||||
)
|
||||
|
||||
return final_mask_mod if self.causal else sliding_window_mask_mod
|
||||
|
||||
def get_mask_mod(self):
|
||||
# Stage-1: initialize the base mask_mod
|
||||
# (causal mask for decoder or bidirectional mask for encoder)
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
# stage-2: add external mask_mod for special attention during
|
||||
# forwarding runtime to create the combined mask_mod.
|
||||
if self.sliding_window is not None:
|
||||
# Add sliding window mask for sliding window attention
|
||||
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
|
||||
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
|
||||
return mask_mod
|
||||
|
||||
def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
|
||||
"""Creates the transformed score_mod function for FlexAttention.
|
||||
|
||||
@ -472,12 +520,9 @@ class FlexAttentionMetadata:
|
||||
return BlockMask.from_kv_blocks(**block_mask_kwargs)
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
kv_len = self.total_cache_tokens
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
kv_len = self.num_actual_tokens
|
||||
mask_mod = self.get_mask_mod()
|
||||
kv_len = (self.total_cache_tokens
|
||||
if self.causal else self.num_actual_tokens)
|
||||
return create_block_mask_compiled(
|
||||
mask_mod,
|
||||
None,
|
||||
@ -498,11 +543,7 @@ class FlexAttentionMetadata:
|
||||
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
|
||||
if self.causal:
|
||||
self.mask_mod = self.get_causal_mask_mod()
|
||||
else:
|
||||
self.mask_mod = self.get_bidirectional_mask_mod()
|
||||
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.transformed_score_mod = self.get_transformed_score_mod()
|
||||
|
||||
if self.direct_build and self.causal:
|
||||
@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(
|
||||
|
||||
|
||||
class FlexAttentionImpl(AttentionImpl):
|
||||
sliding_window: Optional[tuple[int, int]]
|
||||
sliding_window: Optional[int]
|
||||
alibi_slopes: Optional[torch.Tensor]
|
||||
logits_soft_cap: Optional[float]
|
||||
|
||||
@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
"FlexAttention does not support alibi slopes yet.")
|
||||
else:
|
||||
self.alibi_slopes = None
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support sliding window yet.")
|
||||
else:
|
||||
self.sliding_window = (-1, -1)
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if self.logits_soft_cap is not None:
|
||||
@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
if attn_metadata.sliding_window != self.sliding_window:
|
||||
attn_metadata.sliding_window = self.sliding_window
|
||||
if attn_metadata.direct_build:
|
||||
# TODO: Support skipping the computation of sliding window
|
||||
# in direct block mask building code path.
|
||||
logger.warning_once(
|
||||
"Using direct block mask building with sliding window, "
|
||||
"which is suboptimal now. Performance may be degraded.")
|
||||
# update mask mod in attention metadata
|
||||
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
|
||||
attn_metadata.block_mask = (
|
||||
attn_metadata._build_block_mask_direct())
|
||||
else:
|
||||
attn_metadata.block_mask = attn_metadata.build_block_mask()
|
||||
|
||||
if not attn_metadata.causal:
|
||||
assert self.attn_type == AttentionType.ENCODER_ONLY
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user