From f825c6bd22133a8b2242457069f59654a2ae401b Mon Sep 17 00:00:00 2001
From: Maximilien de Bayser
Date: Wed, 6 Aug 2025 22:37:14 -0300
Subject: [PATCH] Support encoder_only attention for FlexAttention (#22273)

Signed-off-by: Max de Bayser
---
 tests/kernels/test_flex_attention.py         | 88 +++++++++++++-----
 vllm/v1/attention/backends/flex_attention.py | 95 ++++++++++++++------
 2 files changed, 137 insertions(+), 46 deletions(-)

diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py
index e25556c89fb9..f76bd192460c 100644
--- a/tests/kernels/test_flex_attention.py
+++ b/tests/kernels/test_flex_attention.py
@@ -9,7 +9,9 @@
 import pytest
 import torch
 from packaging import version
-from vllm import LLM, SamplingParams
+from vllm import SamplingParams
+
+from ..models.utils import check_embeddings_close
 
 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,7 +30,7 @@ def set_seed(seed):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_flex_attention_vs_default_backend(monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     """
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
-    max_tokens = 32
+    max_tokens = 24
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
 
         set_seed(seed)
-
-        llm_flex = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_flex = llm_flex.generate(prompts, sampling_params)
+    with vllm_runner(model_name,
+                     runner="generate",
+                     tensor_parallel_size=1,
+                     num_gpu_blocks_override=128,
+                     enforce_eager=True) as llm_flex:
+        output_flex = llm_flex.generate(prompts, sampling_params)
 
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         set_seed(seed)
-        llm_default = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_default = llm_default.generate(prompts, sampling_params)
+    with vllm_runner(model_name,
+                     runner="generate",
+                     tensor_parallel_size=1,
+                     num_gpu_blocks_override=128,
+                     enforce_eager=True) as llm_default:
+        output_default = llm_default.generate(prompts, sampling_params)
 
     # Compare outputs from both backends
     for i, (flex_result,
             default_result) in enumerate(zip(output_flex, output_default)):
         prompt = prompts[i]
-        flex_text = flex_result.outputs[0].text
-        default_text = default_result.outputs[0].text
+        flex_text = flex_result[1][0]
+        default_text = default_result[1][0]
 
         assert flex_text == default_text, (
             f"FlexAttention output doesn't match default for: {prompt!r}\n"
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
             f"Default: {default_text!r}")
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
+    reason="CUDA not available or PyTorch version < 2.7",
+)
+def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+    """Test that FlexAttention produces the same outputs as the default backend.
+
+    This test compares the outputs from the FlexAttention backend with
+    the default backend for encoder models.
+    """
+    model_name = "BAAI/bge-base-en-v1.5"
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+
+    # Run with flex attention
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+    with vllm_runner(model_name,
+                     runner="pooling",
+                     dtype=torch.bfloat16,
+                     tensor_parallel_size=1,
+                     max_model_len=100,
+                     enforce_eager=True) as llm_flex:
+        flex_outputs = llm_flex.embed(prompts)
+
+    # Run with default backend
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+    with vllm_runner(model_name,
+                     runner="pooling",
+                     dtype=torch.bfloat16,
+                     tensor_parallel_size=1,
+                     max_model_len=100,
+                     enforce_eager=True) as llm_default:
+        default_outputs = llm_default.embed(prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=flex_outputs,
+        embeddings_1_lst=default_outputs,
+        name_0="flex",
+        name_1="default",
+        tol=1e-2,
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index bb0d890c7754..e599411b2d7e 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
 
 @dataclass
 class FlexAttentionMetadata:
+    causal: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
     num_blocks = 0
     block_mask: Optional[BlockMask] = None
     score_mod: Optional[_score_mod_signature] = None
-    mask_mod: Optional[_mask_mod_signature] = None
     logical_mask_mod: _mask_mod_signature = causal_mask_mod
 
-    def get_mask_mod(self) -> _mask_mod_signature:
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
         """Creates the mask_mod function for FlexAttention.
 
         This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
 
         return final_mask_mod
 
+    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
+        """Creates the encoder mask_mod function for FlexAttention.
+
+        Since the encoder bidirectional attention doesn't run with
+        KV cache, this function creates a mask based on the
+        packed query sequences.
+        """
+        # Create a lookup mapping from query indices -> request number
+        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            return request_lookup[q_idx] == request_lookup[kv_idx]
+
+        return final_mask_mod
+
     def build_block_mask(self) -> BlockMask:
-        assert self.mask_mod is not None
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+            kv_len = self.total_cache_tokens
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+            kv_len = self.num_actual_tokens
         return create_block_mask_compiled(
-            self.mask_mod,
+            mask_mod,
             None,
             None,
             self.num_actual_tokens,
-            self.total_cache_tokens,
+            kv_len,
             device=self.block_table.device,
         )
 
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
         assert self.prefix_kv_lens is None, "Not implemented yet."
         assert self.suffix_kv_lens is None, "Not implemented yet."
         self.num_blocks = self.total_cache_tokens // self.block_size
-        self.mask_mod = self.get_mask_mod()
         self.block_mask = self.build_block_mask()
 
 
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
                                                 self.device, non_blocking=True)
 
         out = FlexAttentionMetadata(
+            causal=common_attn_metadata.causal,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.attn_type = attn_type
+
+        if attn_type not in (AttentionType.ENCODER_ONLY,
+                             AttentionType.DECODER):
+            raise NotImplementedError(
+                f"FlexAttention does not support {attn_type} attention")
 
         if alibi_slopes is not None:
             raise NotImplementedError(
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
+        if not attn_metadata.causal:
+            assert self.attn_type == AttentionType.ENCODER_ONLY
 
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key, value),
+            )
+
+        else:
+            assert self.attn_type == AttentionType.DECODER
+            key_cache, value_cache = kv_cache.unbind(0)
+
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+            # View out the block_size dim
+            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
+            value_cache = value_cache.view(-1, self.num_kv_heads,
+                                           self.head_size)
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key_cache, value_cache),
+            )
 
-        # View out the block_size dim
-        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
-        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
-        query, key_cache, value_cache = map(
-            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
-            (query, key_cache, value_cache),
-        )
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
 
         out = flex_attention_compiled(
             query,
-            key_cache,
-            value_cache,
+            key_tensor,
+            value_tensor,
             attn_metadata.score_mod,
             attn_metadata.block_mask,
             self.scale,
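
Illustrative sketch (not part of the patch): the bidirectional mask added by get_bidirectional_mask_mod lets every token of a packed encoder request attend to all tokens of the same request, and to nothing else, using only query_start_loc. The standalone offsets_to_doc_ids helper below is a hypothetical stand-in for vLLM's internal _offsets_to_doc_ids_tensor, and the tensors are made-up example values; only plain PyTorch is assumed.

import torch

def offsets_to_doc_ids(query_start_loc: torch.Tensor) -> torch.Tensor:
    """Map each packed token index to the request (document) it belongs to."""
    seq_lens = query_start_loc[1:] - query_start_loc[:-1]
    return torch.repeat_interleave(
        torch.arange(len(seq_lens), device=query_start_loc.device), seq_lens)

# Three packed requests of lengths 3, 2, and 4 -> 9 tokens total.
query_start_loc = torch.tensor([0, 3, 5, 9])
doc_ids = offsets_to_doc_ids(query_start_loc)   # tensor([0,0,0,1,1,2,2,2,2])

# Bidirectional (encoder-only) mask: block-diagonal over packed requests,
# i.e. q_idx and kv_idx may interact only when they share a request id.
mask = doc_ids[:, None] == doc_ids[None, :]     # (9, 9) boolean tensor

assert mask[0, 2] and mask[2, 0]    # same request, attention in both directions
assert not mask[0, 3]               # tokens of different requests never attend

This mirrors why the non-causal path in build_block_mask sizes the block mask by num_actual_tokens (the packed query length) rather than total_cache_tokens: encoder-only attention runs without the KV cache, while the causal decoder path keeps the existing cache-backed causal mask mod.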