diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 0b7e103beca6..8a4fc15791b0 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -1,15 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 attention backends without GPUModelRunner dependency.""" +from functools import partial +from typing import Optional, Union import pytest import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) +from vllm.config import ModelConfig +from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) @@ -183,13 +188,19 @@ class MockAttentionLayer: self._v_scale_float = 1.0 -def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - layer_names: list[str], vllm_config, - device: torch.device, - common_attn_metadata: CommonAttentionMetadata, - query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor) -> torch.Tensor: +def run_attention_backend( + backend: _Backend, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + sliding_window: Optional[int] = None, +) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" # Handle special case for FLEX_ATTENTION_SLOW @@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, scale=scale, num_kv_heads=num_kv_heads, alibi_slopes=None, - sliding_window=None, + sliding_window=sliding_window, kv_cache_dtype="auto", ) @@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, return output -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_backend_correctness(batch_spec_name: str, model: str): +def _test_backend_correctness( + batch_spec: BatchSpec, + model: str, + backend_to_test: list[Union[_Backend, str]], + mask_mod, + *, + block_size: int = 16, + atol: float = 1e-2, + rtol: float = 1e-2, +): """ Test that all backends produce similar outputs to a reference implementation using torch.nn.functional.scaled_dot_product_attention. @@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" - batch_spec = BATCH_SPECS[batch_spec_name] + current_platform.seed_everything(42) vllm_config = create_vllm_config(model_name=model, max_model_len=max(batch_spec.seq_lens), + block_size=block_size, num_gpu_blocks=8192) device = torch.device("cuda:0") @@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str): num_kv_heads = vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() + sliding_window = vllm_config.model_config.get_sliding_window() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size scale = 1.0 / (head_size**0.5) @@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str): # Create causal mask: query token i attends to positions 0 to # (context_len + i) kv_len = s_len - offset = context_len - attn_mask = torch.full((q_len, kv_len), - float('-inf'), - device=device, - dtype=dtype) - for i in range(q_len): - attn_mask[i, :offset + i + 1] = 0.0 - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale, - enable_gqa=True) - # Convert back to (L, H, D) + final_mask_mod = partial(mask_mod, context_len=context_len) + block_mask = create_block_mask(final_mask_mod, + B=None, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + device=device) + sdpa_out_i = flex_attention(q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True) + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) # Inputs for vLLM backends are just the new tokens @@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str): # 4. Run vLLM backends and compare # Note: flex_attention has known Triton kernel compatibility issues # with test infrastructures - for backend_name in BACKENDS_TO_TEST: + for backend_name in backend_to_test: # FlashAttentionm + FlexAttention: # [2, num_blocks, block_size, num_kv_heads, head_size] # FlashInfer: @@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str): 2, 3).contiguous().transpose(2, 3) set_kv_cache_layout("HND") - backend_output = run_attention_backend(backend_name, kv_cache_spec, - ["placeholder"], vllm_config, - device, common_attn_metadata, - query_vllm, key_vllm, - value_vllm, - kv_cache_for_backend) + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + ) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( @@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str): f"[{backend_name}] produced non-finite values") # Check numerical similarity - rtol = 1e-2 - atol = 5e-3 + def error_msg(msg: str, backend_name: str): + return (f"[{backend_name}] output differs from SDPA baseline. " + f"{msg}") - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() - max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() - all_close = torch.allclose(backend_output, + torch.testing.assert_close(backend_output, sdpa_output, rtol=rtol, - atol=atol) + atol=atol, + msg=partial(error_msg, + backend_name=backend_name)) - assert all_close, ( - f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_causal_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with causal attention.""" + + def causal_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + ): + return (q_idx + context_len) >= kv_idx + + batch_spec = BATCH_SPECS[batch_spec_name] + LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") else []) + SMALL_BLOCK_BACKENDS = [ + x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, + causal_mask_mod) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness(batch_spec, + model, + LARGE_BLOCK_BACKENDS, + causal_mask_mod, + block_size=128) + + +SLIDING_WINDOW_BACKENDS_TO_TEST = [ + _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW" +] + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_medium", "large_decode", + "large_prefill" +]) +@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) +def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with sliding window attention.""" + + def sliding_window_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + sliding_window: int, + ): + causal_mask = q_idx + context_len >= kv_idx + window_mask = q_idx + context_len - kv_idx < sliding_window + return causal_mask & window_mask + + batch_spec = BATCH_SPECS[batch_spec_name] + model_config = ModelConfig(model=model, + max_model_len=max(batch_spec.seq_lens)) + sliding_window = model_config.get_sliding_window() + sliding_window_mask_mod_fn = partial(sliding_window_mask_mod, + sliding_window=sliding_window) + + LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") else []) + SMALL_BLOCK_BACKENDS = [ + x for x in SLIDING_WINDOW_BACKENDS_TO_TEST + if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, + sliding_window_mask_mod_fn) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness(batch_spec, + model, + LARGE_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + block_size=128) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 662d3984554a..c3358bfa74e9 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -9,7 +9,7 @@ import torch import torch._dynamo.decorators import torch.nn.functional as F from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - _score_mod_signature, + _score_mod_signature, and_masks, create_block_mask, flex_attention) @@ -292,6 +292,7 @@ class FlexAttentionMetadata: q_block_size: int = 16 kv_block_size: int = 16 transformed_score_mod: Optional[_score_mod_signature] = None + sliding_window: 
Optional[int] = None def _convert_physical_to_logical( self, @@ -380,6 +381,53 @@ class FlexAttentionMetadata: return final_mask_mod + def get_sliding_window_mask_mod(self) -> _mask_mod_signature: + """Creates the sliding window mask_mod function for FlexAttention. + + Note that the sliding window mask here is bidirectional, we need + to mask it with the bidirectional/causal mask for encoder/decoder. + """ + + if self.sliding_window is None: + raise ValueError( + "sliding_window must be set for sliding window attention") + + def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor, + q_idx: torch.Tensor, kv_idx: torch.Tensor): + return torch.abs(q_idx - kv_idx) < self.sliding_window + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + self.doc_ids, q_idx, physical_kv_idx) + return torch.where( + is_valid, + sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod if self.causal else sliding_window_mask_mod + + def get_mask_mod(self): + # Stage-1: initialize the base mask_mod + # (causal mask for decoder or bidirectional mask for encoder) + if self.causal: + mask_mod = self.get_causal_mask_mod() + else: + mask_mod = self.get_bidirectional_mask_mod() + # stage-2: add external mask_mod for special attention during + # forwarding runtime to create the combined mask_mod. + if self.sliding_window is not None: + # Add sliding window mask for sliding window attention + sliding_window_mask_mod = self.get_sliding_window_mask_mod() + mask_mod = and_masks(mask_mod, sliding_window_mask_mod) + return mask_mod + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: """Creates the transformed score_mod function for FlexAttention. 
@@ -472,12 +520,9 @@ class FlexAttentionMetadata: return BlockMask.from_kv_blocks(**block_mask_kwargs) def build_block_mask(self) -> BlockMask: - if self.causal: - mask_mod = self.get_causal_mask_mod() - kv_len = self.total_cache_tokens - else: - mask_mod = self.get_bidirectional_mask_mod() - kv_len = self.num_actual_tokens + mask_mod = self.get_mask_mod() + kv_len = (self.total_cache_tokens + if self.causal else self.num_actual_tokens) return create_block_mask_compiled( mask_mod, None, @@ -498,11 +543,7 @@ class FlexAttentionMetadata: self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - if self.causal: - self.mask_mod = self.get_causal_mask_mod() - else: - self.mask_mod = self.get_bidirectional_mask_mod() - + self.mask_mod = self.get_mask_mod() self.transformed_score_mod = self.get_transformed_score_mod() if self.direct_build and self.causal: @@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder( class FlexAttentionImpl(AttentionImpl): - sliding_window: Optional[tuple[int, int]] + sliding_window: Optional[int] alibi_slopes: Optional[torch.Tensor] logits_soft_cap: Optional[float] @@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl): "FlexAttention does not support alibi slopes yet.") else: self.alibi_slopes = None - if sliding_window is not None: - raise NotImplementedError( - "FlexAttention does not support sliding window yet.") - else: - self.sliding_window = (-1, -1) + + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: @@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens + if attn_metadata.sliding_window != self.sliding_window: + attn_metadata.sliding_window = self.sliding_window + if attn_metadata.direct_build: + # TODO: Support skipping the computation of sliding window + # in direct block mask building code path. + logger.warning_once( + "Using direct block mask building with sliding window, " + "which is suboptimal now. Performance may be degraded.") + # update mask mod in attention metadata + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + attn_metadata.block_mask = ( + attn_metadata._build_block_mask_direct()) + else: + attn_metadata.block_mask = attn_metadata.build_block_mask() + if not attn_metadata.causal: assert self.attn_type == AttentionType.ENCODER_ONLY
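
Illustrative sketch (not part of the patch): the core of the flex_attention.py change is that FlexAttentionMetadata.get_mask_mod() now composes the base causal/bidirectional mask_mod with a sliding-window mask_mod via and_masks(). Below is a minimal, self-contained PyTorch example of that composition pattern; the window size, tensor shapes, and dtype are made-up values for illustration only.

import torch
from torch.nn.attention.flex_attention import (and_masks, create_block_mask,
                                               flex_attention)

WINDOW = 4  # illustrative sliding-window size


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Query token i may attend to key positions <= i.
    return q_idx >= kv_idx


def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional window; causality is enforced by the mask above,
    # matching the note in get_sliding_window_mask_mod().
    return torch.abs(q_idx - kv_idx) < WINDOW


combined_mask_mod = and_masks(causal_mask_mod, sliding_window_mask_mod)

B, H, S, D = 1, 4, 128, 64  # batch, heads, seq len, head dim (illustrative)
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))
block_mask = create_block_mask(combined_mask_mod, B=None, H=None,
                               Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask, scale=D**-0.5)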
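
Illustrative sketch (not part of the patch): the reworked tests build their ground truth with flex_attention instead of a dense SDPA mask. Each request's query covers only its new tokens while K/V cover the full sequence, so the mask_mod shifts q_idx by context_len, bound with functools.partial, mirroring what _test_backend_correctness does per sequence. The sizes and head counts below are made-up values for illustration only.

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len: int):
    # New query token i sits at absolute position context_len + i.
    return (q_idx + context_len) >= kv_idx


context_len, q_len = 120, 8                # 120 cached tokens, 8 new tokens
kv_len = context_len + q_len
num_q_heads, num_kv_heads, head_dim = 8, 2, 64

q = torch.randn(1, num_q_heads, q_len, head_dim,
                device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, num_kv_heads, kv_len, head_dim,
                device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, num_kv_heads, kv_len, head_dim,
                device="cuda", dtype=torch.bfloat16)

block_mask = create_block_mask(partial(causal_mask_mod, context_len=context_len),
                               B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len,
                               device="cuda")
# enable_gqa lets 8 query heads share 2 KV heads; output is (1, 8, q_len, head_dim).
out = flex_attention(q, k, v, block_mask=block_mask,
                     scale=head_dim**-0.5, enable_gqa=True)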