[v1] Add real sliding window calculation to FlexAttention direct BlockMask building (#26015)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Co-authored-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Isotr0py authored 2025-12-01 21:12:51 +08:00, committed by GitHub
parent ad9d656bfa
commit b95db244ee
2 changed files with 38 additions and 7 deletions
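In short: for causal attention with a sliding window of width W, a query at logical position q only attends to KV positions in [max(0, q - (W - 1)), q], so every KV block below (q - (W - 1)) // block_size can be excluded while the block mask is built, instead of only being masked out at score time. A minimal standalone sketch of that arithmetic (not the vLLM code itself; the names mirror the diff below):

import torch

def min_visible_block(
    logical_q_idx: torch.Tensor, sliding_window: int, block_size: int
) -> torch.Tensor:
    # First KV block index each query token can still see under the window.
    min_kv_idx = torch.clamp(logical_q_idx - (sliding_window - 1), min=0)
    return min_kv_idx // block_size

q = torch.tensor([0, 100, 1000])
print(min_visible_block(q, sliding_window=512, block_size=16))
# tensor([ 0,  0, 30]) -- the token at position 1000 skips blocks 0..29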


@@ -74,6 +74,9 @@ BATCH_SPECS = {
     ),
     "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
     "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
+    "mixed_large": BatchSpec(
+        seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
+    ),
     "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
     "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
 }
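The new "mixed_large" spec batches three pure-decode requests (one query token each) with three prefill chunks (32 query tokens each) across several context lengths, so the sliding-window block-mask path gets exercised on a heterogeneous batch. A quick sanity check of its shape (assuming BatchSpec simply pairs these two lists, as the other entries suggest):

seq_lens = [1024, 2048, 4096, 1024, 2048, 4096]
query_lens = [1, 1, 1, 32, 32, 32]
assert len(seq_lens) == len(query_lens) == 6
assert sum(query_lens) == 99          # total query tokens in the batch
assert all(q <= s for q, s in zip(query_lens, seq_lens))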
@@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [
 @pytest.mark.parametrize(
     "batch_spec_name",
-    ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
+    [
+        "small_decode",
+        "small_prefill",
+        "mixed_medium",
+        "large_decode",
+        "large_prefill",
+        "mixed_large",
+    ],
 )
 @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])


@@ -4,6 +4,7 @@
 import math
 from dataclasses import dataclass
+from functools import cached_property
 from typing import ClassVar

 import torch
@@ -315,6 +316,14 @@ class FlexAttentionMetadata:
     transformed_score_mod: _score_mod_signature | None = None
     sliding_window: int | None = None

+    @cached_property
+    def logical_block_ids(self):
+        return torch.arange(
+            cdiv(self.max_seq_len, self.block_size),
+            device=self.block_table.device,
+            dtype=torch.long,
+        )
+
     def _convert_physical_to_logical(
         self,
         request_lookup: torch.Tensor,
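logical_block_ids is just the range of logical KV block indices for the longest sequence in the batch, cached so the sliding-window masking below can reuse it without reallocating. With concrete numbers (cdiv is ceiling division, as in vLLM's utils; the sizes here are hypothetical):

import torch

def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division

max_seq_len, block_size = 4096, 16
logical_block_ids = torch.arange(cdiv(max_seq_len, block_size), dtype=torch.long)
print(logical_block_ids.shape)   # torch.Size([256])
print(logical_block_ids[:4])     # tensor([0, 1, 2, 3])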
@@ -493,6 +502,7 @@ class FlexAttentionMetadata:
         The direct path works as follows:
         1. For each query token, fetch blocks from block_table using max_seq_len
+           and exclude out of sliding window blocks if needed.
            (this fetches more blocks than needed for shorter sequences)
         2. Group query tokens into chunks of q_block_size
         3. For each group, deduplicate the blocks using unique_static_unsorted
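A rough illustration of the grouping and deduplication in steps 2-3, using plain torch.unique in place of vLLM's unique_static_unsorted helper (the real helper keeps shapes static for compilation; this sketch with made-up block ids only shows the intent):

import torch

q_block_size = 4
# Hypothetical per-token KV block ids fetched in step 1 for one group of
# q_block_size query tokens (0 here stands for already-masked entries).
used_pages = torch.tensor([
    [0, 3, 7, 0],
    [0, 3, 7, 9],
    [3, 7, 9, 0],
    [3, 7, 9, 11],
])
group = used_pages[:q_block_size].flatten()
print(torch.unique(group))  # tensor([ 0,  3,  7,  9, 11])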
@@ -517,6 +527,23 @@
         used_pages = self.block_table[
             self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
         ]
+
+        if self.sliding_window and self.causal:
+            device = used_pages.device
+            assert self.doc_ids is not None
+            token_indices = torch.arange(
+                self.doc_ids.shape[0], device=device, dtype=torch.long
+            )
+            logical_q_idx = (
+                token_indices
+                - self.query_start_loc[self.doc_ids]
+                + self.decode_offset[self.doc_ids]
+            )
+            min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
+            min_block_idx = min_kv_idx // self.block_size
+            sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
+            used_pages.masked_fill_(~sliding_mask, 0)
+
         used_pages_padded = pad_to_multiple(
             used_pages, multiple=self.q_block_size, dim=0
         )
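Tracing the new branch with hypothetical sizes (sliding_window=512, block_size=16): a decode token at logical position 1000 may only attend to KV positions 489..1000, so logical blocks 0..29 are overwritten with page 0 by masked_fill_, and after the deduplication step they collapse to at most one entry instead of thirty:

import torch

sliding_window, block_size = 512, 16
logical_q_idx = torch.tensor([1000])
min_kv_idx = torch.clamp(logical_q_idx - (sliding_window - 1), min=0)  # tensor([489])
min_block_idx = min_kv_idx // block_size                               # tensor([30])
logical_block_ids = torch.arange(64, dtype=torch.long)
sliding_mask = logical_block_ids >= min_block_idx[:, None]
print(sliding_mask[0, 28:32])  # tensor([False, False,  True,  True])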
@@ -785,12 +812,6 @@ class FlexAttentionImpl(AttentionImpl):
         if attn_metadata.sliding_window != self.sliding_window:
             attn_metadata.sliding_window = self.sliding_window
-            if attn_metadata.direct_build:
-                # TODO: Support skipping the computation of sliding window
-                # in direct block mask building code path.
-                logger.warning_once(
-                    "Using direct block mask building with sliding window, "
-                    "which is suboptimal now. Performance may be degraded."
-                )
             # update mask mod in attention metadata
             attn_metadata.mask_mod = attn_metadata.get_mask_mod()
             attn_metadata.block_mask = attn_metadata._build_block_mask_direct()