Support encoder_only attention for FlexAttention (#22273)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>

Parent commit: 41b67f4263
This commit:  f825c6bd22
@@ -9,7 +9,9 @@ import pytest
 import torch
 from packaging import version
 
-from vllm import LLM, SamplingParams
+from vllm import SamplingParams
+
+from ..models.utils import check_embeddings_close
 
 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,7 +30,7 @@ def set_seed(seed):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_flex_attention_vs_default_backend(monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     """
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     seed = 42
-    max_tokens = 32
+    max_tokens = 24
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
 
         set_seed(seed)
-
-        llm_flex = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_flex = llm_flex.generate(prompts, sampling_params)
+        with vllm_runner(model_name,
+                         runner="generate",
+                         tensor_parallel_size=1,
+                         num_gpu_blocks_override=128,
+                         enforce_eager=True) as llm_flex:
+            output_flex = llm_flex.generate(prompts, sampling_params)
 
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         set_seed(seed)
-        llm_default = LLM(
-            model_name,
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        )
-        output_default = llm_default.generate(prompts, sampling_params)
+        with vllm_runner(model_name,
+                         runner="generate",
+                         tensor_parallel_size=1,
+                         num_gpu_blocks_override=128,
+                         enforce_eager=True) as llm_default:
+            output_default = llm_default.generate(prompts, sampling_params)
 
     # Compare outputs from both backends
     for i, (flex_result,
             default_result) in enumerate(zip(output_flex, output_default)):
         prompt = prompts[i]
-        flex_text = flex_result.outputs[0].text
-        default_text = default_result.outputs[0].text
+        flex_text = flex_result[1][0]
+        default_text = default_result[1][0]
 
         assert flex_text == default_text, (
             f"FlexAttention output doesn't match default for: {prompt!r}\n"
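Illustrative sketch (not part of the diff): with vllm.LLM.generate() the first sample's text lives at result.outputs[0].text, while the vllm_runner fixture hands back plain tuples, which is why the hunk above now indexes result[1][0]. The (token_ids, texts) layout below is assumed from that indexing, not quoted from the fixture:

# Hypothetical result shaped like what the [1][0] indexing above implies:
# index 0 -> token ids per sample, index 1 -> decoded texts per sample.
fixture_style_result = ([[101, 102, 103]], ["Hello, my name is ..."])
first_text = fixture_style_result[1][0]
assert first_text == "Hello, my name is ..."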
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
             f"Default: {default_text!r}")
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
+    reason="CUDA not available or PyTorch version < 2.7",
+)
+def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+    """Test that FlexAttention produces the same outputs as the default backend.
+
+    This test compares the outputs from the FlexAttention backend with
+    the default backend for encoder models.
+    """
+    model_name = "BAAI/bge-base-en-v1.5"
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+
+    # Run with flex attention
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_flex:
+            flex_outputs = llm_flex.embed(prompts)
+
+    # Run with default backend
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_default:
+            default_outputs = llm_default.embed(prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=flex_outputs,
+        embeddings_1_lst=default_outputs,
+        name_0="flex",
+        name_1="default",
+        tol=1e-2,
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
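Illustrative sketch (not part of the diff): check_embeddings_close is vLLM's own test helper. As a rough standalone picture of what such a closeness check does, the snippet below compares embedding lists by cosine similarity; the function name, layout, and tolerance semantics are assumptions for illustration, not the helper's actual implementation:

import torch

def embeddings_close_sketch(a, b, tol=1e-2):
    # Compare two lists of embedding vectors by cosine similarity; a similarity
    # of 1.0 means identical direction, so 1 - sim must stay under tol.
    for i, (e0, e1) in enumerate(zip(a, b)):
        e0 = torch.tensor(e0, dtype=torch.float32)
        e1 = torch.tensor(e1, dtype=torch.float32)
        sim = torch.nn.functional.cosine_similarity(e0, e1, dim=0)
        assert 1 - sim.item() <= tol, (
            f"embedding {i} differs: 1 - cos = {1 - sim.item():.4f}")

embeddings_close_sketch([[0.1, 0.2, 0.3]], [[0.1001, 0.2, 0.3]])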
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
 
 @dataclass
 class FlexAttentionMetadata:
+    causal: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
     num_blocks = 0
     block_mask: Optional[BlockMask] = None
     score_mod: Optional[_score_mod_signature] = None
-    mask_mod: Optional[_mask_mod_signature] = None
     logical_mask_mod: _mask_mod_signature = causal_mask_mod
 
-    def get_mask_mod(self) -> _mask_mod_signature:
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
         """Creates the mask_mod function for FlexAttention.
 
         This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
 
         return final_mask_mod
 
+    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
+        """Creates the encoder mask_mod function for FlexAttention.
+
+        Since the encoder bidirectional attention doesn't run with
+        KV cache, this function creates a mask based on the
+        packed query sequences.
+        """
+        # Create a lookup mapping from query indices -> request number
+        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            return request_lookup[q_idx] == request_lookup[kv_idx]
+
+        return final_mask_mod
+
     def build_block_mask(self) -> BlockMask:
-        assert self.mask_mod is not None
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+            kv_len = self.total_cache_tokens
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+            kv_len = self.num_actual_tokens
         return create_block_mask_compiled(
-            self.mask_mod,
+            mask_mod,
             None,
             None,
             self.num_actual_tokens,
-            self.total_cache_tokens,
+            kv_len,
             device=self.block_table.device,
         )
 
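Illustrative sketch (not part of the diff): the bidirectional mask above keys off a query-index to request-id lookup built from query_start_loc. The self-contained snippet below shows the idea; offsets_to_doc_ids is an assumed stand-in for vLLM's _offsets_to_doc_ids_tensor, not the backend code itself:

import torch

def offsets_to_doc_ids(offsets: torch.Tensor) -> torch.Tensor:
    # [0, 3, 5, 9] -> [0, 0, 0, 1, 1, 2, 2, 2, 2]: token i belongs to request doc_ids[i].
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(torch.arange(counts.numel()), counts)

query_start_loc = torch.tensor([0, 3, 5, 9])
request_lookup = offsets_to_doc_ids(query_start_loc)

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # Allow attention only between tokens of the same packed request,
    # in both directions (no causal constraint).
    return request_lookup[q_idx] == request_lookup[kv_idx]

n = int(query_start_loc[-1])
q_idx = torch.arange(n).view(-1, 1)
kv_idx = torch.arange(n).view(1, -1)
print(bidirectional_mask_mod(0, 0, q_idx, kv_idx).int())
# Prints a block-diagonal 0/1 matrix: one block per packed request.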
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
         assert self.prefix_kv_lens is None, "Not implemented yet."
         assert self.suffix_kv_lens is None, "Not implemented yet."
         self.num_blocks = self.total_cache_tokens // self.block_size
-        self.mask_mod = self.get_mask_mod()
         self.block_mask = self.build_block_mask()
 
 
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
             self.device, non_blocking=True)
 
         out = FlexAttentionMetadata(
+            causal=common_attn_metadata.causal,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.attn_type = attn_type
 
+        if attn_type not in (AttentionType.ENCODER_ONLY,
+                             AttentionType.DECODER):
+            raise NotImplementedError(
+                f"FlexAttention does not support {attn_type} attention")
+
         if alibi_slopes is not None:
             raise NotImplementedError(
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
-
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if not attn_metadata.causal:
+            assert self.attn_type == AttentionType.ENCODER_ONLY
+
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key, value),
+            )
+
+        else:
+            assert self.attn_type == AttentionType.DECODER
+            key_cache, value_cache = kv_cache.unbind(0)
+
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+            # View out the block_size dim
+            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
+            value_cache = value_cache.view(-1, self.num_kv_heads,
+                                           self.head_size)
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key_cache, value_cache),
+            )
 
-        # View out the block_size dim
-        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
-        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
-        query, key_cache, value_cache = map(
-            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
-            (query, key_cache, value_cache),
-        )
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
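Illustrative sketch (not part of the diff): both branches end with the same reshape, exposing packed [num_tokens, num_heads, head_size] tensors to FlexAttention as [1, num_heads, num_tokens, head_size]. Below, view_as_4d_like is an assumed stand-in for the backend's view_as_4d helper:

import torch

def view_as_4d_like(x: torch.Tensor) -> torch.Tensor:
    # Add a leading batch dimension of 1:
    # [tokens, heads, head_size] -> [1, tokens, heads, head_size].
    return x.view(1, *x.shape)

num_tokens, num_heads, head_size = 9, 4, 8
query = torch.randn(num_tokens, num_heads, head_size)

# The subsequent permute puts heads before tokens, which is the layout
# flex_attention expects: [batch, heads, seq_len, head_dim].
query_4d = view_as_4d_like(query).permute(0, 2, 1, 3)
print(query_4d.shape)  # torch.Size([1, 4, 9, 8])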
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
 
         out = flex_attention_compiled(
             query,
-            key_cache,
-            value_cache,
+            key_tensor,
+            value_tensor,
             attn_metadata.score_mod,
             attn_metadata.block_mask,
             self.scale,
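Illustrative sketch (not part of the diff): flex_attention_compiled and create_block_mask_compiled are the backend's torch.compile wrappers around the public torch.nn.attention.flex_attention API. A rough end-to-end sketch of the same call pattern for the encoder-only case follows; it assumes PyTorch >= 2.7 (matching the test gating above) and ideally a CUDA device, with sizes kept friendly to the default mask block size:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"

# Three packed requests of lengths 48, 32 and 48 (128 tokens total).
query_start_loc = torch.tensor([0, 48, 80, 128], device=device)
counts = query_start_loc[1:] - query_start_loc[:-1]
doc_ids = torch.repeat_interleave(
    torch.arange(counts.numel(), device=device), counts)

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # Same-request tokens may attend to each other, in both directions.
    return doc_ids[q_idx] == doc_ids[kv_idx]

num_tokens, num_heads, head_dim = 128, 2, 64
q = torch.randn(1, num_heads, num_tokens, head_dim, device=device)
k = torch.randn(1, num_heads, num_tokens, head_dim, device=device)
v = torch.randn(1, num_heads, num_tokens, head_dim, device=device)

# Q_LEN == KV_LEN == num_actual_tokens in the encoder-only path: no KV cache.
block_mask = create_block_mask(bidirectional_mask_mod, None, None,
                               num_tokens, num_tokens, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 2, 128, 64])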