[V1] Add sliding window support to Flex Attention backend (#24089)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-09-21 13:08:07 +08:00 committed by GitHub
parent 7ed82d1974
commit cf56cf78b4
2 changed files with 227 additions and 67 deletions
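The heart of the change is FlexAttention mask_mod composition: the backend's existing causal (decoder) or bidirectional (encoder) mask_mod is AND-ed with a bidirectional sliding-window mask_mod via and_masks before the BlockMask is built. Below is a minimal standalone sketch of that composition, assuming a recent PyTorch with FlexAttention; the shapes and window size are illustrative and none of this is vLLM code.

import torch
from torch.nn.attention.flex_attention import (and_masks, create_block_mask,
                                                flex_attention)

SLIDING_WINDOW = 4  # illustrative window size


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional window; causality comes from the AND with the causal mask.
    return torch.abs(q_idx - kv_idx) < SLIDING_WINDOW


# Same composition get_mask_mod() performs below (minus the paged-KV
# physical-to-logical index translation).
mask_mod = and_masks(causal_mask_mod, sliding_window_mask_mod)

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(1, 2, 128, 64, device=device)  # (B, H, seq, head_dim)
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=128,
                               KV_LEN=128, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)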


@@ -1,15 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""
from functools import partial
from typing import Optional, Union
import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
set_kv_cache_layout)
@@ -183,13 +188,19 @@ class MockAttentionLayer:
self._v_scale_float = 1.0
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], vllm_config,
device: torch.device,
common_attn_metadata: CommonAttentionMetadata,
query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor) -> torch.Tensor:
def run_attention_backend(
backend: _Backend,
kv_cache_spec: FullAttentionSpec,
layer_names: list[str],
vllm_config,
device: torch.device,
common_attn_metadata: CommonAttentionMetadata,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
sliding_window: Optional[int] = None,
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
# Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
sliding_window=sliding_window,
kv_cache_dtype="auto",
)
@@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
return output
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str):
def _test_backend_correctness(
batch_spec: BatchSpec,
model: str,
backend_to_test: list[Union[_Backend, str]],
mask_mod,
*,
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
):
"""
Test that all backends produce similar outputs to a reference implementation
using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
batch_spec = BATCH_SPECS[batch_spec_name]
current_platform.seed_everything(42)
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192)
device = torch.device("cuda:0")
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
sliding_window = vllm_config.model_config.get_sliding_window()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
# Create causal mask: query token i attends to positions 0 to
# (context_len + i)
kv_len = s_len
offset = context_len
attn_mask = torch.full((q_len, kv_len),
float('-inf'),
device=device,
dtype=dtype)
for i in range(q_len):
attn_mask[i, :offset + i + 1] = 0.0
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale,
enable_gqa=True)
# Convert back to (L, H, D)
final_mask_mod = partial(mask_mod, context_len=context_len)
block_mask = create_block_mask(final_mask_mod,
B=None,
H=None,
Q_LEN=q_len,
KV_LEN=kv_len,
device=device)
sdpa_out_i = flex_attention(q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
block_mask=block_mask,
scale=scale,
enable_gqa=True)
all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
# Inputs for vLLM backends are just the new tokens
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
# 4. Run vLLM backends and compare
# Note: flex_attention has known Triton kernel compatibility issues
# with test infrastructures
for backend_name in BACKENDS_TO_TEST:
for backend_name in backend_to_test:
# FlashAttention + FlexAttention:
# [2, num_blocks, block_size, num_kv_heads, head_size]
# FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
2, 3).contiguous().transpose(2, 3)
set_kv_cache_layout("HND")
backend_output = run_attention_backend(backend_name, kv_cache_spec,
["placeholder"], vllm_config,
device, common_attn_metadata,
query_vllm, key_vllm,
value_vllm,
kv_cache_for_backend)
backend_output = run_attention_backend(
backend_name,
kv_cache_spec,
["placeholder"],
vllm_config,
device,
common_attn_metadata,
query_vllm,
key_vllm,
value_vllm,
kv_cache_for_backend,
sliding_window=sliding_window,
)
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
f"[{backend_name}] produced non-finite values")
# Check numerical similarity
rtol = 1e-2
atol = 5e-3
def error_msg(msg: str, backend_name: str):
return (f"[{backend_name}] output differs from SDPA baseline. "
f"{msg}")
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
all_close = torch.allclose(backend_output,
torch.testing.assert_close(backend_output,
sdpa_output,
rtol=rtol,
atol=atol)
atol=atol,
msg=partial(error_msg,
backend_name=backend_name))
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_causal_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with causal attention."""
def causal_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
*,
context_len: int,
):
return (q_idx + context_len) >= kv_idx
batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
causal_mask_mod)
# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
causal_mask_mod,
block_size=128)
SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
]
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_medium", "large_decode",
"large_prefill"
])
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with sliding window attention."""
def sliding_window_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
*,
context_len: int,
sliding_window: int,
):
causal_mask = q_idx + context_len >= kv_idx
window_mask = q_idx + context_len - kv_idx < sliding_window
return causal_mask & window_mask
batch_spec = BATCH_SPECS[batch_spec_name]
model_config = ModelConfig(model=model,
max_model_len=max(batch_spec.seq_lens))
sliding_window = model_config.get_sliding_window()
sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
sliding_window=sliding_window)
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
sliding_window_mask_mod_fn)
# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128)
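For reference, here is a standalone sketch of what the new tests above do on the ground-truth side: the context_len-parameterized mask_mod is bound with functools.partial, turned into a BlockMask, and passed to flex_attention in place of an explicit SDPA attn_mask. The shapes and the window size below are made up; the real tests take sliding_window from the model config.

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def sliding_window_mask_mod(b, h, q_idx, kv_idx, *, context_len: int,
                            sliding_window: int):
    causal_mask = q_idx + context_len >= kv_idx
    window_mask = q_idx + context_len - kv_idx < sliding_window
    return causal_mask & window_mask


device = "cuda" if torch.cuda.is_available() else "cpu"
q_len, kv_len, context_len = 4, 20, 16  # a decode-like step: 4 new tokens
q = torch.randn(1, 8, q_len, 64, device=device)
k = torch.randn(1, 8, kv_len, 64, device=device)
v = torch.randn(1, 8, kv_len, 64, device=device)

mask_mod = partial(sliding_window_mask_mod, context_len=context_len,
                   sliding_window=8)
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=q_len,
                               KV_LEN=kv_len, device=device)
ref_out = flex_attention(q, k, v, block_mask=block_mask)  # reference output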


@@ -9,7 +9,7 @@ import torch
import torch._dynamo.decorators
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
_score_mod_signature,
_score_mod_signature, and_masks,
create_block_mask,
flex_attention)
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
q_block_size: int = 16
kv_block_size: int = 16
transformed_score_mod: Optional[_score_mod_signature] = None
sliding_window: Optional[int] = None
def _convert_physical_to_logical(
self,
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:
return final_mask_mod
def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
"""Creates the sliding window mask_mod function for FlexAttention.
Note that the sliding window mask here is bidirectional; it needs to
be combined with the bidirectional/causal mask for encoder/decoder.
"""
if self.sliding_window is None:
raise ValueError(
"sliding_window must be set for sliding window attention")
def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
q_idx: torch.Tensor, kv_idx: torch.Tensor):
return torch.abs(q_idx - kv_idx) < self.sliding_window
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
return torch.where(
is_valid,
sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod if self.causal else sliding_window_mask_mod
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
if self.causal:
mask_mod = self.get_causal_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()
# Stage-2: combine with additional mask_mods for special attention
# types at forward time to create the combined mask_mod.
if self.sliding_window is not None:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
return mask_mod
def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
"""Creates the transformed score_mod function for FlexAttention.
@@ -472,12 +520,9 @@ class FlexAttentionMetadata:
return BlockMask.from_kv_blocks(**block_mask_kwargs)
def build_block_mask(self) -> BlockMask:
if self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
else:
mask_mod = self.get_bidirectional_mask_mod()
kv_len = self.num_actual_tokens
mask_mod = self.get_mask_mod()
kv_len = (self.total_cache_tokens
if self.causal else self.num_actual_tokens)
return create_block_mask_compiled(
mask_mod,
None,
@@ -498,11 +543,7 @@ class FlexAttentionMetadata:
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.num_blocks = self.total_cache_tokens // self.block_size
if self.causal:
self.mask_mod = self.get_causal_mask_mod()
else:
self.mask_mod = self.get_bidirectional_mask_mod()
self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()
if self.direct_build and self.causal:
@@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(
class FlexAttentionImpl(AttentionImpl):
sliding_window: Optional[tuple[int, int]]
sliding_window: Optional[int]
alibi_slopes: Optional[torch.Tensor]
logits_soft_cap: Optional[float]
@@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
"FlexAttention does not support alibi slopes yet.")
else:
self.alibi_slopes = None
if sliding_window is not None:
raise NotImplementedError(
"FlexAttention does not support sliding window yet.")
else:
self.sliding_window = (-1, -1)
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
if self.logits_soft_cap is not None:
@@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build:
# TODO: Support skipping the sliding window computation in the
# direct block mask building code path.
logger.warning_once(
"Using direct block mask building with sliding window, "
"which is currently suboptimal; performance may be degraded.")
# update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
attn_metadata.block_mask = (
attn_metadata._build_block_mask_direct())
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
if not attn_metadata.causal:
assert self.attn_type == AttentionType.ENCODER_ONLY
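As a sanity check on the composition above: for a single contiguous sequence (ignoring the paged-KV physical-to-logical translation), AND-ing the causal mask with the bidirectional |q_idx - kv_idx| < sliding_window mask is equivalent to the one-sided window mask used as the test reference. The following sketch, with an illustrative window size and standalone mask functions, materializes both with create_mask and compares them.

import torch
from torch.nn.attention.flex_attention import and_masks, create_mask

WINDOW = 4  # illustrative window size
device = "cuda" if torch.cuda.is_available() else "cpu"


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def bidirectional_window(b, h, q_idx, kv_idx):
    return torch.abs(q_idx - kv_idx) < WINDOW


def one_sided_window(b, h, q_idx, kv_idx):
    # The form used by the test reference (with context_len folded to zero).
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW)


dense_combined = create_mask(and_masks(causal, bidirectional_window),
                             B=None, H=None, Q_LEN=16, KV_LEN=16,
                             device=device)
dense_one_sided = create_mask(one_sided_window, B=None, H=None,
                              Q_LEN=16, KV_LEN=16, device=device)
assert torch.equal(dense_combined, dense_one_sided)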