mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:04:58 +08:00
Support encoder_only attention for FlexAttention (#22273)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
parent
41b67f4263
commit
f825c6bd22
@ -9,7 +9,9 @@ import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ..models.utils import check_embeddings_close
|
||||
|
||||
TORCH_VERSION = version.parse(torch.__version__)
|
||||
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
@ -28,7 +30,7 @@ def set_seed(seed):
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_flex_attention_vs_default_backend(monkeypatch):
|
||||
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
|
||||
"""
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
seed = 42
|
||||
max_tokens = 32
|
||||
max_tokens = 24
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
llm_flex = LLM(
|
||||
model_name,
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
output_flex = llm_flex.generate(prompts, sampling_params)
|
||||
with vllm_runner(model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True) as llm_flex:
|
||||
output_flex = llm_flex.generate(prompts, sampling_params)
|
||||
|
||||
# Run with default backend
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
set_seed(seed)
|
||||
llm_default = LLM(
|
||||
model_name,
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
output_default = llm_default.generate(prompts, sampling_params)
|
||||
with vllm_runner(model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True) as llm_default:
|
||||
output_default = llm_default.generate(prompts, sampling_params)
|
||||
|
||||
# Compare outputs from both backends
|
||||
for i, (flex_result,
|
||||
default_result) in enumerate(zip(output_flex, output_default)):
|
||||
prompt = prompts[i]
|
||||
flex_text = flex_result.outputs[0].text
|
||||
default_text = default_result.outputs[0].text
|
||||
flex_text = flex_result[1][0]
|
||||
default_text = default_result[1][0]
|
||||
|
||||
assert flex_text == default_text, (
|
||||
f"FlexAttention output doesn't match default for: {prompt!r}\n"
|
||||
@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
|
||||
f"Default: {default_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
the default backend for encoder models.
|
||||
"""
|
||||
model_name = "BAAI/bge-base-en-v1.5"
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
with vllm_runner(model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True) as llm_flex:
|
||||
flex_outputs = llm_flex.embed(prompts)
|
||||
|
||||
# Run with default backend
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
with vllm_runner(model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True) as llm_default:
|
||||
default_outputs = llm_default.embed(prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=flex_outputs,
|
||||
embeddings_1_lst=default_outputs,
|
||||
name_0="flex",
|
||||
name_1="default",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
causal: bool
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
@ -177,10 +178,9 @@ class FlexAttentionMetadata:
|
||||
num_blocks = 0
|
||||
block_mask: Optional[BlockMask] = None
|
||||
score_mod: Optional[_score_mod_signature] = None
|
||||
mask_mod: Optional[_mask_mod_signature] = None
|
||||
logical_mask_mod: _mask_mod_signature = causal_mask_mod
|
||||
|
||||
def get_mask_mod(self) -> _mask_mod_signature:
|
||||
def get_causal_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the mask_mod function for FlexAttention.
|
||||
|
||||
This function creates the combined mask mod function that handles:
|
||||
@ -233,14 +233,39 @@ class FlexAttentionMetadata:
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the encoder mask_mod function for FlexAttention.
|
||||
|
||||
Since the encoder bidirectional attention doesn't run with
|
||||
KV cache, this function creates a mask based on the
|
||||
packed query sequences.
|
||||
"""
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return request_lookup[q_idx] == request_lookup[kv_idx]
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
assert self.mask_mod is not None
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
kv_len = self.total_cache_tokens
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
kv_len = self.num_actual_tokens
|
||||
return create_block_mask_compiled(
|
||||
self.mask_mod,
|
||||
mask_mod,
|
||||
None,
|
||||
None,
|
||||
self.num_actual_tokens,
|
||||
self.total_cache_tokens,
|
||||
kv_len,
|
||||
device=self.block_table.device,
|
||||
)
|
||||
|
||||
@ -251,7 +276,6 @@ class FlexAttentionMetadata:
|
||||
assert self.prefix_kv_lens is None, "Not implemented yet."
|
||||
assert self.suffix_kv_lens is None, "Not implemented yet."
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
causal=common_attn_metadata.causal,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.attn_type = attn_type
|
||||
|
||||
if attn_type not in (AttentionType.ENCODER_ONLY,
|
||||
AttentionType.DECODER):
|
||||
raise NotImplementedError(
|
||||
f"FlexAttention does not support {attn_type} attention")
|
||||
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError(
|
||||
@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if not attn_metadata.causal:
|
||||
assert self.attn_type == AttentionType.ENCODER_ONLY
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key, value),
|
||||
)
|
||||
|
||||
else:
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads,
|
||||
self.head_size)
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
query, key_cache, value_cache = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
query = query[:, :, :num_actual_tokens, :]
|
||||
# Doesn't work for now -> constraint violation
|
||||
# torch._dynamo.try_mark_dynamic(query, 2)
|
||||
@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
out = flex_attention_compiled(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attn_metadata.score_mod,
|
||||
attn_metadata.block_mask,
|
||||
self.scale,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user