[v1] Add real sliding window calculation to FlexAttention direct BlockMask building (#26015)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Co-authored-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
parent ad9d656bfa
commit b95db244ee
@@ -74,6 +74,9 @@ BATCH_SPECS = {
    ),
    "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
    "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    "mixed_large": BatchSpec(
        seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
    ),
    "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
    "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
}
@@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [

@pytest.mark.parametrize(
    "batch_spec_name",
    ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
    [
        "small_decode",
        "small_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "mixed_large",
    ],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
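For context, the new "mixed_large" spec batches pure decode requests (query_len 1) with chunked prefills (query_len 32) at larger sequence lengths. Below is a minimal sketch of how such a spec flattens into the per-token layout the metadata changes further down operate on; the BatchSpec stand-in and the names here are illustrative, not the actual test helpers.

# Illustrative stand-in for the test helper; only the two fields used above.
from dataclasses import dataclass
from itertools import accumulate


@dataclass
class BatchSpec:
    seq_lens: list[int]
    query_lens: list[int]


spec = BatchSpec(
    seq_lens=[1024, 2048, 4096, 1024, 2048, 4096],
    query_lens=[1, 1, 1, 32, 32, 32],
)

# query_start_loc-style exclusive prefix sum over query_lens:
# three decode tokens, then three 32-token prefill chunks.
query_start_loc = [0, *accumulate(spec.query_lens)]
print(query_start_loc)  # [0, 1, 2, 3, 35, 67, 99]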
@@ -4,6 +4,7 @@

import math
from dataclasses import dataclass
from functools import cached_property
from typing import ClassVar

import torch
@@ -315,6 +316,14 @@ class FlexAttentionMetadata:
    transformed_score_mod: _score_mod_signature | None = None
    sliding_window: int | None = None

    @cached_property
    def logical_block_ids(self):
        return torch.arange(
            cdiv(self.max_seq_len, self.block_size),
            device=self.block_table.device,
            dtype=torch.long,
        )

    def _convert_physical_to_logical(
        self,
        request_lookup: torch.Tensor,
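The added logical_block_ids is just a cached arange over the largest number of blocks any request in the batch can reference, i.e. cdiv(max_seq_len, block_size) entries; the sliding-window pruning further down broadcasts against it. A standalone sketch with made-up sizes (block_size 16, max_seq_len 100) and a local cdiv for illustration:

import torch


def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring the cdiv helper used in the diff.
    return -(-a // b)


max_seq_len, block_size = 100, 16  # assumed example values
logical_block_ids = torch.arange(cdiv(max_seq_len, block_size), dtype=torch.long)
print(logical_block_ids)  # tensor([0, 1, 2, 3, 4, 5, 6])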
@@ -493,6 +502,7 @@ class FlexAttentionMetadata:

        The direct path works as follows:
        1. For each query token, fetch blocks from block_table using max_seq_len
           and exclude out of sliding window blocks if needed.
           (this fetches more blocks than needed for shorter sequences)
        2. Group query tokens into chunks of q_block_size
        3. For each group, deduplicate the blocks using unique_static_unsorted
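A toy illustration of step 1 above (not the real helper code): every query token reads its request's row of the block table out to cdiv(max_seq_len, block_size) entries, so shorter sequences pull in trailing padding entries that later stages discard. Shapes and page ids below are made up.

import torch

block_size = 16
seq_lens = torch.tensor([40, 100])          # two requests; lengths are made up
max_seq_len = int(seq_lens.max())           # 100 -> ceil(100 / 16) = 7 blocks per row
num_blocks = -(-max_seq_len // block_size)

# Toy block table: row i holds the physical page ids of request i,
# zero-padded past the request's own length.
block_table = torch.tensor(
    [
        [3, 7, 9, 0, 0, 0, 0, 0],
        [2, 4, 6, 8, 10, 12, 14, 0],
    ]
)
doc_ids = torch.tensor([0, 0, 1])           # request index for each query token
used_pages = block_table[doc_ids, :num_blocks]
print(used_pages.shape)                     # torch.Size([3, 7])
# Request 0 only needs ceil(40 / 16) = 3 blocks but fetched 7 of them.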
@@ -517,6 +527,23 @@ class FlexAttentionMetadata:
        used_pages = self.block_table[
            self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
        ]

        if self.sliding_window and self.causal:
            device = used_pages.device
            assert self.doc_ids is not None
            token_indices = torch.arange(
                self.doc_ids.shape[0], device=device, dtype=torch.long
            )
            logical_q_idx = (
                token_indices
                - self.query_start_loc[self.doc_ids]
                + self.decode_offset[self.doc_ids]
            )
            min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
            min_block_idx = min_kv_idx // self.block_size
            sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
            used_pages.masked_fill_(~sliding_mask, 0)

        used_pages_padded = pad_to_multiple(
            used_pages, multiple=self.q_block_size, dim=0
        )
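A worked example of the pruning above, with made-up sizes (block_size 16, sliding_window 128, 512-token rows): a query at logical position 300 may only attend to keys in [173, 300], so every block below 173 // 16 = 10 is dropped from its row before the BlockMask is built. The following is a standalone sketch of just the arithmetic, not the metadata class itself.

import torch

block_size, sliding_window = 16, 128                 # assumed example values
num_blocks = 512 // block_size                       # 32 logical blocks
logical_block_ids = torch.arange(num_blocks, dtype=torch.long)

# Logical (within-request) positions of three query tokens:
# one early, one mid-sequence, one late.
logical_q_idx = torch.tensor([5, 150, 300])

# Earliest key index each query may still attend to under the window.
min_kv_idx = torch.clamp(logical_q_idx - (sliding_window - 1), min=0)  # [0, 23, 173]
min_block_idx = min_kv_idx // block_size                               # [0, 1, 10]

# Keep a block for a query only if it can still hold in-window keys.
sliding_mask = logical_block_ids >= min_block_idx[:, None]             # (3, 32) bool
print(min_block_idx.tolist())         # [0, 1, 10]
print(sliding_mask[2, :12].tolist())  # blocks 0-9 dropped for the query at 300

In the diff, the dropped entries are overwritten with page 0 via used_pages.masked_fill_(~sliding_mask, 0); the exact per-token window boundary inside the surviving blocks is presumably still enforced by the mask_mod, so the pruning only shrinks the set of blocks the BlockMask has to carry.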
@@ -785,12 +812,6 @@ class FlexAttentionImpl(AttentionImpl):
        if attn_metadata.sliding_window != self.sliding_window:
            attn_metadata.sliding_window = self.sliding_window
            if attn_metadata.direct_build:
                # TODO: Support skipping the computation of sliding window
                # in direct block mask building code path.
                logger.warning_once(
                    "Using direct block mask building with sliding window, "
                    "which is suboptimal now. Performance may be degraded."
                )
            # update mask mod in attention metadata
            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
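With the direct path now doing the real window pruning, the TODO and the "suboptimal" warning_once shown above appear to be what this hunk removes; the lazy per-layer update (refresh mask_mod and the direct BlockMask only when sliding_window actually changes) stays. For readers unfamiliar with FlexAttention, a generic sliding-window mask_mod looks roughly like the sketch below; this is the standard pattern, not necessarily vllm's exact get_mask_mod, which also has to map physical block positions back to logical ones.

import torch


def make_sliding_window_mask_mod(sliding_window: int):
    # Standard FlexAttention-style mask_mod: causal AND within the window.
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        in_window = q_idx - kv_idx < sliding_window
        return causal & in_window

    return mask_mod


mm = make_sliding_window_mask_mod(128)
print(mm(0, 0, torch.tensor(200), torch.tensor(50)))  # tensor(False): 150 >= 128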