[V1] Add sliding window support to Flex Attention backend (#24089)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 7ed82d1974
commit cf56cf78b4
@@ -1,15 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for v1 attention backends without GPUModelRunner dependency."""
+from functools import partial
+from typing import Optional, Union
 
 import pytest
 import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.config import ModelConfig
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
@@ -183,13 +188,19 @@ class MockAttentionLayer:
         self._v_scale_float = 1.0
 
 
-def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          layer_names: list[str], vllm_config,
-                          device: torch.device,
-                          common_attn_metadata: CommonAttentionMetadata,
-                          query: torch.Tensor, key: torch.Tensor,
-                          value: torch.Tensor,
-                          kv_cache: torch.Tensor) -> torch.Tensor:
+def run_attention_backend(
+    backend: _Backend,
+    kv_cache_spec: FullAttentionSpec,
+    layer_names: list[str],
+    vllm_config,
+    device: torch.device,
+    common_attn_metadata: CommonAttentionMetadata,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    sliding_window: Optional[int] = None,
+) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
     # Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
         scale=scale,
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
-        sliding_window=None,
+        sliding_window=sliding_window,
         kv_cache_dtype="auto",
     )
 
@@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     return output
 
 
-@pytest.mark.parametrize("batch_spec_name", [
-    "small_decode", "small_prefill", "mixed_small", "medium_decode",
-    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
-    "single_decode", "single_prefill"
-])
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_backend_correctness(batch_spec_name: str, model: str):
+def _test_backend_correctness(
+    batch_spec: BatchSpec,
+    model: str,
+    backend_to_test: list[Union[_Backend, str]],
+    mask_mod,
+    *,
+    block_size: int = 16,
+    atol: float = 1e-2,
+    rtol: float = 1e-2,
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    batch_spec = BATCH_SPECS[batch_spec_name]
+    current_platform.seed_everything(42)
     vllm_config = create_vllm_config(model_name=model,
                                      max_model_len=max(batch_spec.seq_lens),
+                                     block_size=block_size,
                                      num_gpu_blocks=8192)
     device = torch.device("cuda:0")
 
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
         vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
+    sliding_window = vllm_config.model_config.get_sliding_window()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
     scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # Create causal mask: query token i attends to positions 0 to
         # (context_len + i)
         kv_len = s_len
-        offset = context_len
-        attn_mask = torch.full((q_len, kv_len),
-                               float('-inf'),
-                               device=device,
-                               dtype=dtype)
-        for i in range(q_len):
-            attn_mask[i, :offset + i + 1] = 0.0
 
-        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
-            q_sdpa_in,
-            k_sdpa_in,
-            v_sdpa_in,
-            attn_mask=attn_mask,
-            scale=scale,
-            enable_gqa=True)
-        # Convert back to (L, H, D)
+        final_mask_mod = partial(mask_mod, context_len=context_len)
+        block_mask = create_block_mask(final_mask_mod,
+                                       B=None,
+                                       H=None,
+                                       Q_LEN=q_len,
+                                       KV_LEN=kv_len,
+                                       device=device)
+        sdpa_out_i = flex_attention(q_sdpa_in,
+                                    k_sdpa_in,
+                                    v_sdpa_in,
+                                    block_mask=block_mask,
+                                    scale=scale,
+                                    enable_gqa=True)
+
         all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
 
         # Inputs for vLLM backends are just the new tokens
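The hunk above replaces the hand-built additive attn_mask and SDPA call with a BlockMask built from the caller-supplied mask_mod, binding context_len through functools.partial so that query indices for the new tokens are offset into the full KV history. A minimal standalone sketch of that binding pattern (illustrative sizes and names, not vLLM code; it assumes PyTorch's FlexAttention utilities):

from functools import partial

from torch.nn.attention.flex_attention import create_block_mask


def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len: int):
    # q_idx counts only the new (suffix) tokens; shift it by the cached prefix
    # so it compares against absolute KV positions.
    return (q_idx + context_len) >= kv_idx


q_len, kv_len = 128, 256          # illustrative lengths
context_len = kv_len - q_len      # tokens already in the KV cache
final_mask_mod = partial(causal_mask_mod, context_len=context_len)
block_mask = create_block_mask(final_mask_mod,
                               B=None,
                               H=None,
                               Q_LEN=q_len,
                               KV_LEN=kv_len,
                               device="cpu")
print(block_mask)  # shows the block-sparsity pattern of the causal suffix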
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
     # with test infrastructures
-    for backend_name in BACKENDS_TO_TEST:
+    for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
         # FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
                 2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")
 
-        backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               ["placeholder"], vllm_config,
-                                               device, common_attn_metadata,
-                                               query_vllm, key_vllm,
-                                               value_vllm,
-                                               kv_cache_for_backend)
+        backend_output = run_attention_backend(
+            backend_name,
+            kv_cache_spec,
+            ["placeholder"],
+            vllm_config,
+            device,
+            common_attn_metadata,
+            query_vllm,
+            key_vllm,
+            value_vllm,
+            kv_cache_for_backend,
+            sliding_window=sliding_window,
+        )
 
         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             f"[{backend_name}] produced non-finite values")
 
         # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-3
+        def error_msg(msg: str, backend_name: str):
+            return (f"[{backend_name}] output differs from SDPA baseline. "
+                    f"{msg}")
 
-        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - sdpa_output) /
-            torch.abs(sdpa_output)).item()
-        all_close = torch.allclose(backend_output,
+        torch.testing.assert_close(backend_output,
                                    sdpa_output,
                                    rtol=rtol,
-                                   atol=atol)
-
-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
+                                   atol=atol,
+                                   msg=partial(error_msg,
+                                               backend_name=backend_name))
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_small", "medium_decode",
+    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
+    "single_decode", "single_prefill"
+])
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_causal_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with causal attention."""
+
+    def causal_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+    ):
+        return (q_idx + context_len) >= kv_idx
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              causal_mask_mod)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  causal_mask_mod,
+                                  block_size=128)
+
+
+SLIDING_WINDOW_BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+]
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_medium", "large_decode",
+    "large_prefill"
+])
+@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
+def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with sliding window attention."""
+
+    def sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
+        causal_mask = q_idx + context_len >= kv_idx
+        window_mask = q_idx + context_len - kv_idx < sliding_window
+        return causal_mask & window_mask
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model,
+                               max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
+                                         sliding_window=sliding_window)
+
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
+        if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              sliding_window_mask_mod_fn)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  sliding_window_mask_mod_fn,
+                                  block_size=128)
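The two new tests drive _test_backend_correctness with a causal mask_mod and a sliding-window mask_mod; the remaining hunks below wire the same window semantics into the FlexAttention backend itself (FlexAttentionMetadata and FlexAttentionImpl). As a quick standalone sanity check of the predicate used by sliding_window_mask_mod above (not part of the test suite; sizes are illustrative): with absolute query position p = q_idx + context_len, only keys k with p - sliding_window < k <= p should be visible.

import torch

context_len, q_len, kv_len, sliding_window = 6, 4, 10, 3

q_idx = torch.arange(q_len).unsqueeze(1)    # (q_len, 1), new tokens only
kv_idx = torch.arange(kv_len).unsqueeze(0)  # (1, kv_len), absolute positions

causal_mask = q_idx + context_len >= kv_idx
window_mask = q_idx + context_len - kv_idx < sliding_window
mask = causal_mask & window_mask

# Brute-force reference: position p sees exactly the last `sliding_window`
# positions up to and including itself.
expected = torch.zeros(q_len, kv_len, dtype=torch.bool)
for i in range(q_len):
    p = i + context_len
    for k in range(kv_len):
        expected[i, k] = (p - sliding_window) < k <= p

assert torch.equal(mask, expected)
print(mask.int())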
@@ -9,7 +9,7 @@ import torch
 import torch._dynamo.decorators
 import torch.nn.functional as F
 from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
-                                                _score_mod_signature,
+                                                _score_mod_signature, and_masks,
                                                 create_block_mask,
                                                 flex_attention)
 
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
     q_block_size: int = 16
     kv_block_size: int = 16
     transformed_score_mod: Optional[_score_mod_signature] = None
+    sliding_window: Optional[int] = None
 
     def _convert_physical_to_logical(
         self,
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:
 
         return final_mask_mod
 
+    def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
+        """Creates the sliding window mask_mod function for FlexAttention.
+
+        Note that the sliding window mask here is bidirectional, we need
+        to mask it with the bidirectional/causal mask for encoder/decoder.
+        """
+
+        if self.sliding_window is None:
+            raise ValueError(
+                "sliding_window must be set for sliding window attention")
+
+        def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
+                                    q_idx: torch.Tensor, kv_idx: torch.Tensor):
+            return torch.abs(q_idx - kv_idx) < self.sliding_window
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            (is_valid, logical_q_idx,
+             logical_kv_idx) = self._convert_physical_to_logical(
+                 self.doc_ids, q_idx, physical_kv_idx)
+            return torch.where(
+                is_valid,
+                sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
+                False,
+            )
+
+        return final_mask_mod if self.causal else sliding_window_mask_mod
+
+    def get_mask_mod(self):
+        # Stage-1: initialize the base mask_mod
+        # (causal mask for decoder or bidirectional mask for encoder)
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+        # stage-2: add external mask_mod for special attention during
+        # forwarding runtime to create the combined mask_mod.
+        if self.sliding_window is not None:
+            # Add sliding window mask for sliding window attention
+            sliding_window_mask_mod = self.get_sliding_window_mask_mod()
+            mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
+        return mask_mod
+
     def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
         """Creates the transformed score_mod function for FlexAttention.
 
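get_mask_mod above composes the base mask_mod (causal for decoder, bidirectional for encoder) with the sliding-window mask_mod through and_masks, which returns a new mask_mod whose result is the elementwise AND of its inputs. A small illustrative check of that composition outside vLLM (the window size and index ranges are made up):

import torch
from torch.nn.attention.flex_attention import and_masks

SLIDING_WINDOW = 4  # illustrative


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional window, as in get_sliding_window_mask_mod above.
    return torch.abs(q_idx - kv_idx) < SLIDING_WINDOW


combined = and_masks(causal_mask_mod, sliding_window_mask_mod)

b = h = torch.tensor(0)
q_idx = torch.arange(8).unsqueeze(1)
kv_idx = torch.arange(8).unsqueeze(0)
expected = (causal_mask_mod(b, h, q_idx, kv_idx)
            & sliding_window_mask_mod(b, h, q_idx, kv_idx))
assert torch.equal(combined(b, h, q_idx, kv_idx), expected)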
@@ -472,12 +520,9 @@ class FlexAttentionMetadata:
         return BlockMask.from_kv_blocks(**block_mask_kwargs)
 
     def build_block_mask(self) -> BlockMask:
-        if self.causal:
-            mask_mod = self.get_causal_mask_mod()
-            kv_len = self.total_cache_tokens
-        else:
-            mask_mod = self.get_bidirectional_mask_mod()
-            kv_len = self.num_actual_tokens
+        mask_mod = self.get_mask_mod()
+        kv_len = (self.total_cache_tokens
+                  if self.causal else self.num_actual_tokens)
         return create_block_mask_compiled(
             mask_mod,
             None,
@@ -498,11 +543,7 @@ class FlexAttentionMetadata:
         self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
         self.num_blocks = self.total_cache_tokens // self.block_size
 
-        if self.causal:
-            self.mask_mod = self.get_causal_mask_mod()
-        else:
-            self.mask_mod = self.get_bidirectional_mask_mod()
-
+        self.mask_mod = self.get_mask_mod()
         self.transformed_score_mod = self.get_transformed_score_mod()
 
         if self.direct_build and self.causal:
@@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(
 
 
 class FlexAttentionImpl(AttentionImpl):
-    sliding_window: Optional[tuple[int, int]]
+    sliding_window: Optional[int]
     alibi_slopes: Optional[torch.Tensor]
     logits_soft_cap: Optional[float]
 
@@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
                 "FlexAttention does not support alibi slopes yet.")
         else:
             self.alibi_slopes = None
-        if sliding_window is not None:
-            raise NotImplementedError(
-                "FlexAttention does not support sliding window yet.")
-        else:
-            self.sliding_window = (-1, -1)
+        self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         if self.logits_soft_cap is not None:
@@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
+        if attn_metadata.sliding_window != self.sliding_window:
+            attn_metadata.sliding_window = self.sliding_window
+            if attn_metadata.direct_build:
+                # TODO: Support skipping the computation of sliding window
+                # in direct block mask building code path.
+                logger.warning_once(
+                    "Using direct block mask building with sliding window, "
+                    "which is suboptimal now. Performance may be degraded.")
+                # update mask mod in attention metadata
+                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+                attn_metadata.block_mask = (
+                    attn_metadata._build_block_mask_direct())
+            else:
+                attn_metadata.block_mask = attn_metadata.build_block_mask()
+
         if not attn_metadata.causal:
             assert self.attn_type == AttentionType.ENCODER_ONLY
 
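Taken together, these hunks thread sliding_window from the attention layer into FlexAttentionMetadata, compose it with the causal or bidirectional base mask via and_masks, and rebuild the block mask whenever a layer's window differs from the metadata. The standalone sketch below (not vLLM code; it assumes PyTorch's FlexAttention API, uses illustrative sizes, and only runs when CUDA is available) shows the resulting masking pattern end to end.

import torch
from torch.nn.attention.flex_attention import (and_masks, create_block_mask,
                                                flex_attention)

SLIDING_WINDOW = 8  # illustrative window size


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Decoder-style causal mask: token i sees positions 0..i.
    return q_idx >= kv_idx


def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional window; AND-ed with the causal mask it yields the usual
    # "last SLIDING_WINDOW tokens" pattern.
    return torch.abs(q_idx - kv_idx) < SLIDING_WINDOW


if torch.cuda.is_available():
    device, seq_len, n_heads, head_dim = "cuda", 128, 4, 64
    q = torch.randn(1, n_heads, seq_len, head_dim, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)

    mask_mod = and_masks(causal_mask_mod, sliding_window_mask_mod)
    block_mask = create_block_mask(mask_mod,
                                   B=None,
                                   H=None,
                                   Q_LEN=seq_len,
                                   KV_LEN=seq_len,
                                   device=device)
    out = flex_attention(q, k, v, block_mask=block_mask)
    print(out.shape)  # torch.Size([1, 4, 128, 64])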