Enable hybrid attention models for Transformers backend (#18494)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-23 04:12:08 +02:00 committed by GitHub
parent c6b636f9fb
commit 4b0da7b60e
4 changed files with 106 additions and 30 deletions


@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model.
+- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
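To make the second bullet concrete, here is a minimal sketch of how a modeling file might pick the per-layer window and hand it to vLLM's `Attention` layer. It follows the pattern used by the Transformers-backend change later in this commit; `build_attn`, its parameter list, and the `sliding_window_pattern` / `interleaved_sliding_window` attribute names on `config` are illustrative assumptions, not a fixed vLLM API:

from vllm.attention import Attention

def build_attn(config, layer_idx: int, num_heads: int, head_size: int,
               num_kv_heads: int, cache_config=None, quant_config=None):
    # Every `sliding_window_pattern`-th layer ((layer_idx + 1) divisible by the
    # pattern) is a full-attention layer; all other layers use the interleaved
    # sliding window.
    sliding_window = None
    pattern = getattr(config, "sliding_window_pattern", None)
    interleaved = getattr(config, "interleaved_sliding_window", None)
    if pattern is not None and interleaved is not None:
        if (layer_idx + 1) % pattern > 0:
            sliding_window = interleaved

    return Attention(
        num_heads=num_heads,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        # The per-layer value is what vLLM reads instead of a global window.
        per_layer_sliding_window=sliding_window,
        prefix=f"{layer_idx}.attn")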


@@ -1,37 +1,50 @@
 # SPDX-License-Identifier: Apache-2.0
 """Test the functionality of the Transformers backend."""
+from typing import Any, Optional, Union
+
 import pytest
 
 from vllm.platforms import current_platform
 
 from ..conftest import HfRunner, VllmRunner
+from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
 from ..utils import multi_gpu_test
 from .utils import check_logprobs_close
 
 
 def check_implementation(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
+    runner_ref: type[Union[HfRunner, VllmRunner]],
+    runner_test: type[VllmRunner],
     example_prompts: list[str],
     model: str,
+    kwargs_ref: Optional[dict[str, Any]] = None,
+    kwargs_test: Optional[dict[str, Any]] = None,
     **kwargs,
 ):
+    if kwargs_ref is None:
+        kwargs_ref = {}
+    if kwargs_test is None:
+        kwargs_test = {}
+
     max_tokens = 32
     num_logprobs = 5
+    args = (example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model, **kwargs) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+    with runner_test(model, **kwargs_test, **kwargs) as model_test:
+        outputs_test = model_test.generate_greedy_logprobs(*args)
 
-    with hf_runner(model) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+    with runner_ref(model, **kwargs_ref) as model_ref:
+        if isinstance(model_ref, VllmRunner):
+            outputs_ref = model_ref.generate_greedy_logprobs(*args)
+        else:
+            outputs_ref = model_ref.generate_greedy_logprobs_limit(*args)
 
     check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
+        outputs_0_lst=outputs_ref,
+        outputs_1_lst=outputs_test,
+        name_0="ref",
+        name_1="test",
     )
@@ -58,6 +71,18 @@ def test_models(
                          model_impl=model_impl)
 
 
+def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
+    prompts, _, _ = prep_prompts(4, (800, 801))
+    kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
+    kwargs_test = {"model_impl": "transformers", **kwargs_ref}
+    check_implementation(vllm_runner,
+                         vllm_runner,
+                         prompts,
+                         model="hmellor/tiny-random-Gemma2ForCausalLM",
+                         kwargs_ref=kwargs_ref,
+                         kwargs_test=kwargs_test)
+
+
 @multi_gpu_test(num_gpus=2)
 def test_distributed(
     hf_runner: type[HfRunner],
@@ -65,8 +90,11 @@ def test_distributed(
     example_prompts,
 ):
     kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
-    check_implementation(hf_runner, vllm_runner, example_prompts,
-                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
+    check_implementation(hf_runner,
+                         vllm_runner,
+                         example_prompts,
+                         "meta-llama/Llama-3.2-1B-Instruct",
+                         kwargs_test=kwargs)
 
 
 @pytest.mark.skipif(


@@ -533,13 +533,17 @@ class ModelConfig:
             self.model, hf_token=self.hf_token, revision=self.revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
 
-        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        has_interleaved_attention = (sliding_window is not None) and (
-            isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in interleaved_attn_models))
-        if (not self.disable_sliding_window and has_interleaved_attention):
+        # Workaround for Gemma 2 which uses interleaved sliding window
+        # attention, but it's not specified in its config. TODO: remove this
+        # when Gemma 2 is fixed in Transformers.
+        if self.hf_text_config.model_type == "gemma2":
+            self.hf_text_config.sliding_window_pattern = 2
+
+        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
+        sliding_window_pattern = getattr(self.hf_text_config,
+                                         "sliding_window_pattern", None)
+
+        if not (self.disable_sliding_window or sliding_window_pattern is None):
             if (backend :=
                     envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
@@ -1037,8 +1041,7 @@ class ModelConfig:
         if self.use_async_output_proc:
             self.use_async_output_proc = False
 
-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled."""
 
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1049,7 +1052,7 @@
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
+    def get_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
         # If user disables sliding window, return None.
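As a rough illustration of what the config handling above amounts to for a Gemma-2-style text config, combined with the interleaved_sliding_window/sliding_window swap described in the docs change earlier in this commit. Attribute values are examples only, and SimpleNamespace stands in for the real Hugging Face config object:

from types import SimpleNamespace

# Example attributes; Gemma 2 does not declare sliding_window_pattern itself.
hf_text_config = SimpleNamespace(model_type="gemma2", sliding_window=4096)

# Workaround from the hunk above: patch the missing pattern for Gemma 2.
if hf_text_config.model_type == "gemma2":
    hf_text_config.sliding_window_pattern = 2

sliding_window = getattr(hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(hf_text_config, "sliding_window_pattern", None)

disable_sliding_window = False
if not (disable_sliding_window or sliding_window_pattern is None):
    # As described in the docs change: keep the window under a new name and
    # delete sliding_window so the model is treated as full attention.
    hf_text_config.interleaved_sliding_window = sliding_window
    del hf_text_config.sliding_window

print(getattr(hf_text_config, "sliding_window", None))  # None
print(hf_text_config.interleaved_sliding_window)        # 4096
print(hf_text_config.sliding_window_pattern)            # 2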


@@ -16,6 +16,7 @@
 """Wrapper around `transformers` models"""
 import re
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Literal, Optional, Union
 
 import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
     )
 
 
+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
 class TransformersModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
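A hypothetical usage of the ConfigOverride context manager added above, assuming the class is in scope: it temporarily restores `sliding_window` from `interleaved_sliding_window` while the Transformers model is constructed, then removes it again on exit.

from transformers import PretrainedConfig

# PretrainedConfig stores unknown kwargs as attributes, so this stands in for
# a config that vLLM has already rewritten (sliding_window deleted).
config = PretrainedConfig(interleaved_sliding_window=4096)

override = ConfigOverride(config,
                          sliding_window=config.interleaved_sliding_window)
with override:
    assert config.sliding_window == 4096  # visible while the model is built
assert not hasattr(config, "sliding_window")  # deleted again afterwards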
@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
         # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
             # FIXME(Isotr0py): We need to refactor this part in the future to
             # avoid registering an extra model layer, otherwise we will need a
             # weights mapper to rename weights.
@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances
 
     def init_buffers(self, module: nn.Module):
         """