Support encoder_only attention for FlexAttention (#22273)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Author: Maximilien de Bayser, 2025-08-06 22:37:14 -03:00 (committed by GitHub)
parent 41b67f4263
commit f825c6bd22
2 changed files with 137 additions and 46 deletions


@@ -9,7 +9,9 @@ import pytest
import torch
from packaging import version
from vllm import LLM, SamplingParams
from vllm import SamplingParams
from ..models.utils import check_embeddings_close
TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,7 +30,7 @@ def set_seed(seed):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(monkeypatch):
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42
max_tokens = 32
max_tokens = 24
prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
set_seed(seed)
llm_flex = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_flex = llm_flex.generate(prompts, sampling_params)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_flex:
output_flex = llm_flex.generate(prompts, sampling_params)
# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
set_seed(seed)
llm_default = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_default = llm_default.generate(prompts, sampling_params)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_default:
output_default = llm_default.generate(prompts, sampling_params)
# Compare outputs from both backends
for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i]
flex_text = flex_result.outputs[0].text
default_text = default_result.outputs[0].text
flex_text = flex_result[1][0]
default_text = default_result[1][0]
assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n"
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
f"Default: {default_text!r}")
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
those from the default backend for encoder models.
"""
model_name = "BAAI/bge-base-en-v1.5"
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
]
# Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_flex:
flex_outputs = llm_flex.embed(prompts)
# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_default:
default_outputs = llm_default.embed(prompts)
check_embeddings_close(
embeddings_0_lst=flex_outputs,
embeddings_1_lst=default_outputs,
name_0="flex",
name_1="default",
tol=1e-2,
)
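For context, a minimal usage sketch of what this test exercises end to end: selecting the FlexAttention backend through the environment variable and embedding with an encoder-only model. The constructor arguments mirror what the test passes through `vllm_runner`; treat the exact `LLM` keyword arguments here as an assumption rather than a documented API.

# Usage sketch only; mirrors the test above rather than defining an API.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", runner="pooling", enforce_eager=True)
outputs = llm.embed(["The capital of France is"])
print(len(outputs[0].outputs.embedding))  # embedding width (768 for bge-base)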
if __name__ == "__main__":
pytest.main([__file__])


@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
@dataclass
class FlexAttentionMetadata:
causal: bool
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
num_blocks = 0
block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None
mask_mod: Optional[_mask_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod
def get_mask_mod(self) -> _mask_mod_signature:
def get_causal_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
return final_mask_mod
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
"""Creates the encoder mask_mod function for FlexAttention.
Since encoder bidirectional attention does not use the KV cache,
this function creates a mask based on the packed query sequences.
"""
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
) -> torch.Tensor:
return request_lookup[q_idx] == request_lookup[kv_idx]
return final_mask_mod
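To make the packing idea concrete, a standalone sketch follows (illustrative names, not the vLLM code): tokens from several requests are concatenated into one sequence, a lookup maps each token index to its request, and the mask allows attention only between tokens of the same request, in both directions. `doc_ids` below plays the role of `_offsets_to_doc_ids_tensor(query_start_loc)` above; a CUDA device is assumed.

# Standalone sketch of a packed-sequence bidirectional mask (illustrative only).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"
query_start_loc = torch.tensor([0, 4, 6, 9], device=device)  # 3 requests: 4, 2, 3 tokens
doc_ids = torch.repeat_interleave(
    torch.arange(query_start_loc.numel() - 1, device=device),
    query_start_loc.diff())  # [0, 0, 0, 0, 1, 1, 2, 2, 2]

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # A token may attend to every token of its own request, in both directions.
    return doc_ids[q_idx] == doc_ids[kv_idx]

total_tokens = int(query_start_loc[-1])
block_mask = create_block_mask(bidirectional_mask_mod, None, None,
                               total_tokens, total_tokens, device=device)
q = k = v = torch.randn(1, 2, total_tokens, 16, device=device)  # (B, H, tokens, head_dim)
out = flex_attention(q, k, v, block_mask=block_mask)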
def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
if self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
else:
mask_mod = self.get_bidirectional_mask_mod()
kv_len = self.num_actual_tokens
return create_block_mask_compiled(
self.mask_mod,
mask_mod,
None,
None,
self.num_actual_tokens,
self.total_cache_tokens,
kv_len,
device=self.block_table.device,
)
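The branch above changes only which mask mod is used and how long the key/value axis of the mask is: the causal (decoder) mask spans the whole paged cache, while the bidirectional (encoder) mask spans just the tokens of the current batch. A shape-only sketch with made-up sizes (the real causal mask mod also maps physical block-table slots to logical token positions, which is omitted here):

# Shape-only sketch; not the vLLM mask mods.
import torch
from torch.nn.attention.flex_attention import create_block_mask

def simple_causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def allow_all(b, h, q_idx, kv_idx):
    return q_idx >= 0  # stand-in for the same-request lookup above

num_actual_tokens, total_cache_tokens = 16, 4096
# Decoder: queries are the new tokens, keys/values cover the whole cache.
decoder_mask = create_block_mask(simple_causal, None, None,
                                 num_actual_tokens, total_cache_tokens,
                                 device="cuda")
# Encoder: both axes cover only the packed tokens of the current batch.
encoder_mask = create_block_mask(allow_all, None, None,
                                 num_actual_tokens, num_actual_tokens,
                                 device="cuda")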
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet."
self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod()
self.block_mask = self.build_block_mask()
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
self.device, non_blocking=True)
out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.attn_type = attn_type
if attn_type not in (AttentionType.ENCODER_ONLY,
AttentionType.DECODER):
raise NotImplementedError(
f"FlexAttention does not support {attn_type} attention")
if alibi_slopes is not None:
raise NotImplementedError(
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if not attn_metadata.causal:
assert self.attn_type == AttentionType.ENCODER_ONLY
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key, value),
)
else:
assert self.attn_type == AttentionType.DECODER
key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads,
self.head_size)
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
query, key_cache, value_cache = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
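On the encoder-only path the batch's key and value tensors are used directly instead of being read back from the paged cache, and `view_as_4d(...).permute(0, 2, 1, 3)` turns the `(num_tokens, num_heads, head_size)` layout into the `(1, num_heads, num_tokens, head_size)` layout FlexAttention expects. A rough standalone sketch, approximating `view_as_4d` with `unsqueeze(0)` (an assumption about that helper):

# Layout sketch only; view_as_4d is approximated by unsqueeze(0).
import torch

num_tokens, num_heads, head_size = 9, 2, 16
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)

# (num_tokens, num_heads, head_size) -> (1, num_heads, num_tokens, head_size)
q4, k4, v4 = (x.unsqueeze(0).permute(0, 2, 1, 3) for x in (query, key, value))
assert q4.shape == (1, num_heads, num_tokens, head_size)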
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
out = flex_attention_compiled(
query,
key_cache,
value_cache,
key_tensor,
value_tensor,
attn_metadata.score_mod,
attn_metadata.block_mask,
self.scale,
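The `flex_attention_compiled` callable used here is assumed to be a `torch.compile` wrapper around `torch.nn.attention.flex_attention.flex_attention`; a minimal sketch of such a wrapper (the compile options vLLM actually uses may differ):

# Minimal sketch of a compiled flex_attention wrapper (assumption, see above).
import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)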