[V1] Add sliding window support to Flex Attention backend (#24089)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Isotr0py on 2025-09-21 13:08:07 +08:00; committed by GitHub
parent 7ed82d1974
commit cf56cf78b4
2 changed files with 227 additions and 67 deletions
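
The core of the change: FlexAttention builds its block mask from composable mask_mod functions, and this commit adds a sliding-window mask_mod that is AND-ed onto the existing causal (or bidirectional) mask via torch's and_masks. Below is a minimal standalone sketch of that masking scheme in plain PyTorch; it omits vLLM's paged-KV physical-to-logical index remapping, and the window value and tensor sizes are illustrative assumptions, not values from this PR.

```python
# Minimal sketch (PyTorch >= 2.5 with a CUDA device) of combining a causal
# mask_mod with a bidirectional sliding-window mask_mod, as the FlexAttention
# backend does after this change. No vLLM code or paged KV cache involved.
import torch
from torch.nn.attention.flex_attention import (and_masks, create_block_mask,
                                                flex_attention)

SLIDING_WINDOW = 128  # assumed example value


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional window; the causal mask above restricts it to the past.
    return torch.abs(q_idx - kv_idx) < SLIDING_WINDOW


# AND the two mask_mods together, mirroring get_mask_mod() in the diff below.
mask_mod = and_masks(causal_mask_mod, sliding_window_mask_mod)

B, H, L, D = 1, 8, 1024, 64  # illustrative shapes
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

block_mask = create_block_mask(mask_mod, B=None, H=None,
                               Q_LEN=L, KV_LEN=L, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```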

View File

@@ -1,15 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for v1 attention backends without GPUModelRunner dependency."""
+from functools import partial
+from typing import Optional, Union

 import pytest
 import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention

 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.config import ModelConfig
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
@@ -183,13 +188,19 @@ class MockAttentionLayer:
         self._v_scale_float = 1.0


-def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          layer_names: list[str], vllm_config,
-                          device: torch.device,
-                          common_attn_metadata: CommonAttentionMetadata,
-                          query: torch.Tensor, key: torch.Tensor,
-                          value: torch.Tensor,
-                          kv_cache: torch.Tensor) -> torch.Tensor:
+def run_attention_backend(
+    backend: _Backend,
+    kv_cache_spec: FullAttentionSpec,
+    layer_names: list[str],
+    vllm_config,
+    device: torch.device,
+    common_attn_metadata: CommonAttentionMetadata,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    sliding_window: Optional[int] = None,
+) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""

     # Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
         scale=scale,
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
-        sliding_window=None,
+        sliding_window=sliding_window,
         kv_cache_dtype="auto",
     )
@@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     return output


-@pytest.mark.parametrize("batch_spec_name", [
-    "small_decode", "small_prefill", "mixed_small", "medium_decode",
-    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
-    "single_decode", "single_prefill"
-])
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_backend_correctness(batch_spec_name: str, model: str):
+def _test_backend_correctness(
+    batch_spec: BatchSpec,
+    model: str,
+    backend_to_test: list[Union[_Backend, str]],
+    mask_mod,
+    *,
+    block_size: int = 16,
+    atol: float = 1e-2,
+    rtol: float = 1e-2,
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    batch_spec = BATCH_SPECS[batch_spec_name]
+    current_platform.seed_everything(42)
     vllm_config = create_vllm_config(model_name=model,
                                      max_model_len=max(batch_spec.seq_lens),
+                                     block_size=block_size,
                                      num_gpu_blocks=8192)
     device = torch.device("cuda:0")
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
         vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
+    sliding_window = vllm_config.model_config.get_sliding_window()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
     scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # Create causal mask: query token i attends to positions 0 to
         # (context_len + i)
         kv_len = s_len
-        offset = context_len
-        attn_mask = torch.full((q_len, kv_len),
-                               float('-inf'),
-                               device=device,
-                               dtype=dtype)
-        for i in range(q_len):
-            attn_mask[i, :offset + i + 1] = 0.0
-
-        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
-            q_sdpa_in,
-            k_sdpa_in,
-            v_sdpa_in,
-            attn_mask=attn_mask,
-            scale=scale,
-            enable_gqa=True)
-        # Convert back to (L, H, D)
+
+        final_mask_mod = partial(mask_mod, context_len=context_len)
+        block_mask = create_block_mask(final_mask_mod,
+                                       B=None,
+                                       H=None,
+                                       Q_LEN=q_len,
+                                       KV_LEN=kv_len,
+                                       device=device)
+        sdpa_out_i = flex_attention(q_sdpa_in,
+                                    k_sdpa_in,
+                                    v_sdpa_in,
+                                    block_mask=block_mask,
+                                    scale=scale,
+                                    enable_gqa=True)
         all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

         # Inputs for vLLM backends are just the new tokens
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
     # with test infrastructures
-    for backend_name in BACKENDS_TO_TEST:
+    for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
         # FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
                 2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")

-        backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               ["placeholder"], vllm_config,
-                                               device, common_attn_metadata,
-                                               query_vllm, key_vllm,
-                                               value_vllm,
-                                               kv_cache_for_backend)
+        backend_output = run_attention_backend(
+            backend_name,
+            kv_cache_spec,
+            ["placeholder"],
+            vllm_config,
+            device,
+            common_attn_metadata,
+            query_vllm,
+            key_vllm,
+            value_vllm,
+            kv_cache_for_backend,
+            sliding_window=sliding_window,
+        )

         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             f"[{backend_name}] produced non-finite values")

         # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-3
-
-        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - sdpa_output) /
-            torch.abs(sdpa_output)).item()
-        all_close = torch.allclose(backend_output,
-                                   sdpa_output,
-                                   rtol=rtol,
-                                   atol=atol)
-
-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
+        def error_msg(msg: str, backend_name: str):
+            return (f"[{backend_name}] output differs from SDPA baseline. "
+                    f"{msg}")
+
+        torch.testing.assert_close(backend_output,
+                                   sdpa_output,
+                                   rtol=rtol,
+                                   atol=atol,
+                                   msg=partial(error_msg,
+                                               backend_name=backend_name))
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_small", "medium_decode",
+    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
+    "single_decode", "single_prefill"
+])
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_causal_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with causal attention."""
+
+    def causal_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+    ):
+        return (q_idx + context_len) >= kv_idx
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              causal_mask_mod)

+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  causal_mask_mod,
+                                  block_size=128)
+
+
+SLIDING_WINDOW_BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+]
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_medium", "large_decode",
+    "large_prefill"
+])
+@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
+def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with sliding window attention."""
+
+    def sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
+        causal_mask = q_idx + context_len >= kv_idx
+        window_mask = q_idx + context_len - kv_idx < sliding_window
+        return causal_mask & window_mask
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model,
+                               max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
+                                         sliding_window=sliding_window)
+
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
+        if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              sliding_window_mask_mod_fn)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  sliding_window_mask_mod_fn,
+                                  block_size=128)
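
For reference, the mask_mods used by these tests take a keyword-only context_len that is bound per sequence with functools.partial, so query index i stands for absolute position context_len + i. The sketch below illustrates that convention for a single decode token; the lengths and window value are illustrative assumptions, not values from BATCH_SPECS or the model configs above.

```python
# Sketch of the context_len convention used by the test mask_mods: a decode
# query at absolute position context_len may only attend to the last
# `sliding_window` cached positions.
from functools import partial

import torch


def sliding_window_mask_mod(b, h, q_idx, kv_idx, *, context_len,
                            sliding_window):
    causal_mask = q_idx + context_len >= kv_idx
    window_mask = q_idx + context_len - kv_idx < sliding_window
    return causal_mask & window_mask


context_len, q_len, kv_len, window = 10, 1, 11, 4  # illustrative values
mask_mod = partial(sliding_window_mask_mod,
                   context_len=context_len,
                   sliding_window=window)

q_idx = torch.arange(q_len).view(-1, 1)
kv_idx = torch.arange(kv_len).view(1, -1)
dense = mask_mod(None, None, q_idx, kv_idx)

# Only kv positions 7..10 are visible to the query at absolute position 10.
print(dense.nonzero(as_tuple=True)[1].tolist())  # -> [7, 8, 9, 10]
```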

View File

@@ -9,7 +9,7 @@ import torch
 import torch._dynamo.decorators
 import torch.nn.functional as F
 from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
-                                                _score_mod_signature,
+                                                _score_mod_signature, and_masks,
                                                 create_block_mask,
                                                 flex_attention)
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
     q_block_size: int = 16
     kv_block_size: int = 16
     transformed_score_mod: Optional[_score_mod_signature] = None
+    sliding_window: Optional[int] = None

     def _convert_physical_to_logical(
         self,
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:

         return final_mask_mod

+    def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
+        """Creates the sliding window mask_mod function for FlexAttention.
+
+        Note that the sliding window mask here is bidirectional, we need
+        to mask it with the bidirectional/causal mask for encoder/decoder.
+        """
+        if self.sliding_window is None:
+            raise ValueError(
+                "sliding_window must be set for sliding window attention")
+
+        def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
+                                    q_idx: torch.Tensor, kv_idx: torch.Tensor):
+            return torch.abs(q_idx - kv_idx) < self.sliding_window
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            (is_valid, logical_q_idx,
+             logical_kv_idx) = self._convert_physical_to_logical(
+                 self.doc_ids, q_idx, physical_kv_idx)
+            return torch.where(
+                is_valid,
+                sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
+                False,
+            )
+
+        return final_mask_mod if self.causal else sliding_window_mask_mod
+
+    def get_mask_mod(self):
+        # Stage-1: initialize the base mask_mod
+        # (causal mask for decoder or bidirectional mask for encoder)
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+        # Stage-2: add external mask_mod for special attention during
+        # forwarding runtime to create the combined mask_mod.
+        if self.sliding_window is not None:
+            # Add sliding window mask for sliding window attention
+            sliding_window_mask_mod = self.get_sliding_window_mask_mod()
+            mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
+        return mask_mod
+
     def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
         """Creates the transformed score_mod function for FlexAttention.
@@ -472,12 +520,9 @@ class FlexAttentionMetadata:
         return BlockMask.from_kv_blocks(**block_mask_kwargs)

     def build_block_mask(self) -> BlockMask:
-        if self.causal:
-            mask_mod = self.get_causal_mask_mod()
-            kv_len = self.total_cache_tokens
-        else:
-            mask_mod = self.get_bidirectional_mask_mod()
-            kv_len = self.num_actual_tokens
+        mask_mod = self.get_mask_mod()
+        kv_len = (self.total_cache_tokens
+                  if self.causal else self.num_actual_tokens)
         return create_block_mask_compiled(
             mask_mod,
             None,
@@ -498,11 +543,7 @@ class FlexAttentionMetadata:
         self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
         self.num_blocks = self.total_cache_tokens // self.block_size

-        if self.causal:
-            self.mask_mod = self.get_causal_mask_mod()
-        else:
-            self.mask_mod = self.get_bidirectional_mask_mod()
+        self.mask_mod = self.get_mask_mod()
         self.transformed_score_mod = self.get_transformed_score_mod()

         if self.direct_build and self.causal:
@@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(

 class FlexAttentionImpl(AttentionImpl):
-    sliding_window: Optional[tuple[int, int]]
+    sliding_window: Optional[int]
     alibi_slopes: Optional[torch.Tensor]
     logits_soft_cap: Optional[float]
@@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
                 "FlexAttention does not support alibi slopes yet.")
         else:
             self.alibi_slopes = None
-        if sliding_window is not None:
-            raise NotImplementedError(
-                "FlexAttention does not support sliding window yet.")
-        else:
-            self.sliding_window = (-1, -1)
+        self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         if self.logits_soft_cap is not None:
@@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens

+        if attn_metadata.sliding_window != self.sliding_window:
+            attn_metadata.sliding_window = self.sliding_window
+            if attn_metadata.direct_build:
+                # TODO: Support skipping the computation of sliding window
+                # in direct block mask building code path.
+                logger.warning_once(
+                    "Using direct block mask building with sliding window, "
+                    "which is suboptimal now. Performance may be degraded.")
+                # update mask mod in attention metadata
+                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+                attn_metadata.block_mask = (
+                    attn_metadata._build_block_mask_direct())
+            else:
+                attn_metadata.block_mask = attn_metadata.build_block_mask()
+
         if not attn_metadata.causal:
             assert self.attn_type == AttentionType.ENCODER_ONLY
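
As the docstring of get_sliding_window_mask_mod notes, the window mask itself is bidirectional (|q_idx - kv_idx| < sliding_window) and only becomes the usual one-sided sliding window once AND-ed with the causal mask. A small sanity check of that equivalence in plain PyTorch (no vLLM imports; W and L below are arbitrary example values):

```python
# Brute-force check: (causal AND |q - kv| < W) == (0 <= q - kv < W)
import torch

W, L = 4, 16  # illustrative window size and sequence length
q = torch.arange(L).view(-1, 1)
kv = torch.arange(L).view(1, -1)

causal = q >= kv
bidirectional_window = torch.abs(q - kv) < W
combined = causal & bidirectional_window

one_sided_window = (q - kv >= 0) & (q - kv < W)
assert torch.equal(combined, one_sided_window)
```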