Support encoder_only attention for FlexAttention (#22273)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Maximilien de Bayser authored on 2025-08-06 22:37:14 -03:00 (committed by GitHub)
parent 41b67f4263
commit f825c6bd22
2 changed files with 137 additions and 46 deletions


@@ -9,7 +9,9 @@ import pytest
 import torch
 from packaging import version

-from vllm import LLM, SamplingParams
+from vllm import SamplingParams
+
+from ..models.utils import check_embeddings_close

 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,7 +30,7 @@ def set_seed(seed):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_flex_attention_vs_default_backend(monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     """Test that FlexAttention produces the same outputs as the default backend.

     This test compares the outputs from the FlexAttention backend with
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     """
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     seed = 42
-    max_tokens = 32
+    max_tokens = 24
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
         set_seed(seed)

-        llm_flex = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_flex = llm_flex.generate(prompts, sampling_params)
+        with vllm_runner(model_name,
+                         runner="generate",
+                         tensor_parallel_size=1,
+                         num_gpu_blocks_override=128,
+                         enforce_eager=True) as llm_flex:
+            output_flex = llm_flex.generate(prompts, sampling_params)

     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         set_seed(seed)
-        llm_default = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_default = llm_default.generate(prompts, sampling_params)
+        with vllm_runner(model_name,
+                         runner="generate",
+                         tensor_parallel_size=1,
+                         num_gpu_blocks_override=128,
+                         enforce_eager=True) as llm_default:
+            output_default = llm_default.generate(prompts, sampling_params)

     # Compare outputs from both backends
     for i, (flex_result,
             default_result) in enumerate(zip(output_flex, output_default)):
         prompt = prompts[i]
-        flex_text = flex_result.outputs[0].text
-        default_text = default_result.outputs[0].text
+        flex_text = flex_result[1][0]
+        default_text = default_result[1][0]

         assert flex_text == default_text, (
             f"FlexAttention output doesn't match default for: {prompt!r}\n"
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
             f"Default: {default_text!r}")


+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
+    reason="CUDA not available or PyTorch version < 2.7",
+)
+def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+    """Test that FlexAttention produces the same outputs as the default backend.
+
+    This test compares the outputs from the FlexAttention backend with
+    the default backend for encoder models.
+    """
+    model_name = "BAAI/bge-base-en-v1.5"
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+
+    # Run with flex attention
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_flex:
+            flex_outputs = llm_flex.embed(prompts)
+
+    # Run with default backend
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_default:
+            default_outputs = llm_default.embed(prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=flex_outputs,
+        embeddings_1_lst=default_outputs,
+        name_0="flex",
+        name_1="default",
+        tol=1e-2,
+    )

 if __name__ == "__main__":
     pytest.main([__file__])
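For context, check_embeddings_close (imported above from ..models.utils) asserts that the two embedding lists agree within a tolerance. A rough standalone sketch of that kind of check, assuming a cosine-similarity criterion; the real helper's signature and details may differ:

import torch

def embeddings_close(embeddings_0, embeddings_1, tol=1e-2):
    # Compare embeddings pairwise; treat a pair as "close" when its cosine
    # similarity is within tol of a perfect match (1.0).
    for e0, e1 in zip(embeddings_0, embeddings_1):
        sim = torch.nn.functional.cosine_similarity(
            torch.as_tensor(e0, dtype=torch.float32),
            torch.as_tensor(e1, dtype=torch.float32),
            dim=0,
        )
        assert 1.0 - sim.item() <= tol, f"cosine similarity too low: {sim}"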


@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
 @dataclass
 class FlexAttentionMetadata:
+    causal: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
     num_blocks = 0
     block_mask: Optional[BlockMask] = None
     score_mod: Optional[_score_mod_signature] = None
-    mask_mod: Optional[_mask_mod_signature] = None
     logical_mask_mod: _mask_mod_signature = causal_mask_mod

-    def get_mask_mod(self) -> _mask_mod_signature:
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
         """Creates the mask_mod function for FlexAttention.

         This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
         return final_mask_mod

+    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
+        """Creates the encoder mask_mod function for FlexAttention.
+
+        Since the encoder bidirectional attention doesn't run with
+        KV cache, this function creates a mask based on the
+        packed query sequences.
+        """
+        # Create a lookup mapping from query indices -> request number
+        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            return request_lookup[q_idx] == request_lookup[kv_idx]
+
+        return final_mask_mod
+
     def build_block_mask(self) -> BlockMask:
-        assert self.mask_mod is not None
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+            kv_len = self.total_cache_tokens
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+            kv_len = self.num_actual_tokens
+
         return create_block_mask_compiled(
-            self.mask_mod,
+            mask_mod,
             None,
             None,
             self.num_actual_tokens,
-            self.total_cache_tokens,
+            kv_len,
             device=self.block_table.device,
         )
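To make the non-causal branch above concrete, here is a small standalone sketch of the packed-sequence bidirectional mask idea that get_bidirectional_mask_mod implements, written against PyTorch's public FlexAttention API. The offsets, shapes, and names are illustrative only; this is not the vLLM code itself:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Two packed sequences of lengths 48 and 80 (query_start_loc-style offsets).
query_start_loc = torch.tensor([0, 48, 128])
doc_ids = torch.repeat_interleave(
    torch.arange(query_start_loc.numel() - 1), query_start_loc.diff())

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # Attend freely within a request, never across request boundaries.
    return doc_ids[q_idx] == doc_ids[kv_idx]

total_tokens = int(query_start_loc[-1])
block_mask = create_block_mask(bidirectional_mask_mod, B=None, H=None,
                               Q_LEN=total_tokens, KV_LEN=total_tokens,
                               device="cpu")

# (batch=1, heads=2, seq=total_tokens, head_dim=16), as FlexAttention expects.
q = k = v = torch.randn(1, 2, total_tokens, 16)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 2, 128, 16])

Because every token may attend to every other token of its own request and to nothing else, no KV cache or causal ordering is involved, which is why build_block_mask uses num_actual_tokens instead of total_cache_tokens as the key/value length on this path.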
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
         assert self.prefix_kv_lens is None, "Not implemented yet."
         assert self.suffix_kv_lens is None, "Not implemented yet."
         self.num_blocks = self.total_cache_tokens // self.block_size
-        self.mask_mod = self.get_mask_mod()
         self.block_mask = self.build_block_mask()
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
             self.device, non_blocking=True)

         out = FlexAttentionMetadata(
+            causal=common_attn_metadata.causal,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.attn_type = attn_type
+
+        if attn_type not in (AttentionType.ENCODER_ONLY,
+                             AttentionType.DECODER):
+            raise NotImplementedError(
+                f"FlexAttention does not support {attn_type} attention")

         if alibi_slopes is not None:
             raise NotImplementedError(
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens

-        key_cache, value_cache = kv_cache.unbind(0)
-
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
-
-        # View out the block_size dim
-        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
-        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
-        query, key_cache, value_cache = map(
-            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
-            (query, key_cache, value_cache),
-        )
+        if not attn_metadata.causal:
+            assert self.attn_type == AttentionType.ENCODER_ONLY
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key, value),
+            )
+
+        else:
+            assert self.attn_type == AttentionType.DECODER
+            key_cache, value_cache = kv_cache.unbind(0)
+
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+            # View out the block_size dim
+            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
+            value_cache = value_cache.view(-1, self.num_kv_heads,
+                                           self.head_size)
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key_cache, value_cache),
+            )

         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
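The encoder branch above skips the KV cache entirely and feeds the packed key/value projections straight to FlexAttention; the only reshaping needed is the 4-D view and permute. A tiny illustration of that layout change, assuming view_as_4d yields a (1, tokens, heads, head_dim) view of the packed tensor (shapes here are made up):

import torch

num_tokens, num_heads, head_size = 12, 2, 64
key = torch.randn(num_tokens, num_heads, head_size)  # packed encoder K projection
# (1, tokens, heads, head_dim) permuted to FlexAttention's (batch, heads, seq, head_dim)
key_4d = key.view(1, num_tokens, num_heads, head_size).permute(0, 2, 1, 3)
print(key_4d.shape)  # torch.Size([1, 2, 12, 64])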
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
         out = flex_attention_compiled(
             query,
-            key_cache,
-            value_cache,
+            key_tensor,
+            value_tensor,
             attn_metadata.score_mod,
             attn_metadata.block_mask,
             self.scale,