diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index b46002c5fa8ff..e7ec8380e0a84 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -74,6 +74,9 @@ BATCH_SPECS = {
     ),
     "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
     "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
+    "mixed_large": BatchSpec(
+        seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
+    ),
     "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
     "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
 }
@@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [
 
 @pytest.mark.parametrize(
     "batch_spec_name",
-    ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
+    [
+        "small_decode",
+        "small_prefill",
+        "mixed_medium",
+        "large_decode",
+        "large_prefill",
+        "mixed_large",
+    ],
 )
 @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 8de0a0a11471f..fe92f6570501c 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -4,6 +4,7 @@
 
 import math
 from dataclasses import dataclass
+from functools import cached_property
 from typing import ClassVar
 
 import torch
@@ -315,6 +316,14 @@ class FlexAttentionMetadata:
     transformed_score_mod: _score_mod_signature | None = None
     sliding_window: int | None = None
 
+    @cached_property
+    def logical_block_ids(self):
+        return torch.arange(
+            cdiv(self.max_seq_len, self.block_size),
+            device=self.block_table.device,
+            dtype=torch.long,
+        )
+
     def _convert_physical_to_logical(
         self,
         request_lookup: torch.Tensor,
@@ -493,6 +502,7 @@ class FlexAttentionMetadata:
 
         The direct path works as follows:
         1. For each query token, fetch blocks from block_table using max_seq_len
+           and exclude out of sliding window blocks if needed.
            (this fetches more blocks than needed for shorter sequences)
         2. Group query tokens into chunks of q_block_size
         3. For each group, deduplicate the blocks using unique_static_unsorted
@@ -517,6 +527,23 @@ class FlexAttentionMetadata:
         used_pages = self.block_table[
             self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
         ]
+
+        if self.sliding_window and self.causal:
+            device = used_pages.device
+            assert self.doc_ids is not None
+            token_indices = torch.arange(
+                self.doc_ids.shape[0], device=device, dtype=torch.long
+            )
+            logical_q_idx = (
+                token_indices
+                - self.query_start_loc[self.doc_ids]
+                + self.decode_offset[self.doc_ids]
+            )
+            min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
+            min_block_idx = min_kv_idx // self.block_size
+            sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
+            used_pages.masked_fill_(~sliding_mask, 0)
+
         used_pages_padded = pad_to_multiple(
             used_pages, multiple=self.q_block_size, dim=0
         )
@@ -785,12 +812,6 @@ class FlexAttentionImpl(AttentionImpl):
         if attn_metadata.sliding_window != self.sliding_window:
             attn_metadata.sliding_window = self.sliding_window
             if attn_metadata.direct_build:
-                # TODO: Support skipping the computation of sliding window
-                # in direct block mask building code path.
-                logger.warning_once(
-                    "Using direct block mask building with sliding window, "
-                    "which is suboptimal now. Performance may be degraded."
-                )
                 # update mask mod in attention metadata
                 attn_metadata.mask_mod = attn_metadata.get_mask_mod()
                 attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
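The core of the change is the sliding-window block skipping added to `_build_block_mask_direct`: for each query token at logical position `logical_q_idx`, the earliest KV position that can still be attended is `logical_q_idx - (sliding_window - 1)`, so any page whose logical block index lies entirely before that bound is dropped from `used_pages` before deduplication. Below is a minimal standalone sketch of that index arithmetic; the helper name, toy shapes, and sizes are illustrative assumptions, not vLLM API.

```python
import torch


def min_live_block_per_token(
    logical_q_idx: torch.Tensor,  # [num_tokens] logical position of each query token
    sliding_window: int,
    block_size: int,
) -> torch.Tensor:
    """Earliest logical KV block that can still fall inside the sliding window."""
    min_kv_idx = torch.clamp(logical_q_idx - (sliding_window - 1), min=0)
    return min_kv_idx // block_size


# Toy example: 4 query tokens at logical positions 60..63 of one request,
# sliding_window=16, block_size=16 -> only the last couple of KV blocks matter.
logical_q_idx = torch.arange(60, 64)
min_block_idx = min_live_block_per_token(logical_q_idx, sliding_window=16, block_size=16)

num_blocks = 4
logical_block_ids = torch.arange(num_blocks)
# Blocks strictly before the window are masked out per token, mirroring
# used_pages.masked_fill_(~sliding_mask, 0) in the diff above.
sliding_mask = logical_block_ids >= min_block_idx[:, None]  # [num_tokens, num_blocks]
print(sliding_mask.int())
```

With these toy numbers the mask keeps only blocks 2 and 3 for the first three tokens and only block 3 for the last one, which is why the old "suboptimal" warning in `FlexAttentionImpl` can be removed: the direct build path now prunes out-of-window pages instead of fetching all of them.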