Enable hybrid attention models for Transformers backend (#18494)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-23 04:12:08 +02:00 committed by GitHub
parent c6b636f9fb
commit 4b0da7b60e
4 changed files with 106 additions and 30 deletions


@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
To support a model with interleaving sliding windows, we need to take care of the following details:
- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model.
- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
With these two steps, interleaved sliding windows should work with the model.
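
For example, a decoder layer's constructor might select its window from the interleaving pattern and hand it to the attention layer. The sketch below is illustrative only (`layer_idx`, `config`, and the surrounding module are assumed names, not taken from a real model file); the linked Llama code is the authoritative reference:

```python
from vllm.attention import Attention

# Illustrative fragment of a decoder layer's __init__:
# every layer except each `sliding_window_pattern`-th one uses the
# interleaved sliding window; the remaining layers use full attention.
sliding_window = None
if (layer_idx + 1) % config.sliding_window_pattern > 0:
    sliding_window = config.interleaved_sliding_window

self.attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    per_layer_sliding_window=sliding_window,
    prefix=f"{prefix}.attn",
)
```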


@@ -1,37 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
"""Test the functionality of the Transformers backend."""
from typing import Any, Optional, Union
import pytest
from vllm.platforms import current_platform
from ..conftest import HfRunner, VllmRunner
from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
from ..utils import multi_gpu_test
from .utils import check_logprobs_close
def check_implementation(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    runner_ref: type[Union[HfRunner, VllmRunner]],
    runner_test: type[VllmRunner],
    example_prompts: list[str],
    model: str,
    kwargs_ref: Optional[dict[str, Any]] = None,
    kwargs_test: Optional[dict[str, Any]] = None,
    **kwargs,
):
    if kwargs_ref is None:
        kwargs_ref = {}
    if kwargs_test is None:
        kwargs_test = {}

    max_tokens = 32
    num_logprobs = 5

    with vllm_runner(model, **kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
    args = (example_prompts, max_tokens, num_logprobs)

    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)
    with runner_test(model, **kwargs_test, **kwargs) as model_test:
        outputs_test = model_test.generate_greedy_logprobs(*args)

    with runner_ref(model, **kwargs_ref) as model_ref:
        if isinstance(model_ref, VllmRunner):
            outputs_ref = model_ref.generate_greedy_logprobs(*args)
        else:
            outputs_ref = model_ref.generate_greedy_logprobs_limit(*args)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        outputs_0_lst=outputs_ref,
        outputs_1_lst=outputs_test,
        name_0="ref",
        name_1="test",
    )
@@ -58,6 +71,18 @@ def test_models(
                         model_impl=model_impl)


def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
    prompts, _, _ = prep_prompts(4, (800, 801))
    kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
    kwargs_test = {"model_impl": "transformers", **kwargs_ref}
    check_implementation(vllm_runner,
                         vllm_runner,
                         prompts,
                         model="hmellor/tiny-random-Gemma2ForCausalLM",
                         kwargs_ref=kwargs_ref,
                         kwargs_test=kwargs_test)
@multi_gpu_test(num_gpus=2)
def test_distributed(
    hf_runner: type[HfRunner],
@@ -65,8 +90,11 @@ def test_distributed(
    example_prompts,
):
    kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
    check_implementation(hf_runner, vllm_runner, example_prompts,
                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
    check_implementation(hf_runner,
                         vllm_runner,
                         example_prompts,
                         "meta-llama/Llama-3.2-1B-Instruct",
                         kwargs_test=kwargs)
@pytest.mark.skipif(


@@ -533,13 +533,17 @@ class ModelConfig:
            self.model, hf_token=self.hf_token, revision=self.revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in interleaved_attn_models))
        # Workaround for Gemma 2 which uses interleaved sliding window
        # attention, but it's not specified in its config. TODO: remove this
        # when Gemma 2 is fixed in Transformers.
        if self.hf_text_config.model_type == "gemma2":
            self.hf_text_config.sliding_window_pattern = 2

        if (not self.disable_sliding_window and has_interleaved_attention):
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        sliding_window_pattern = getattr(self.hf_text_config,
                                         "sliding_window_pattern", None)

        if not (self.disable_sliding_window or sliding_window_pattern is None):
            if (backend :=
                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                sliding_window_len_min = get_min_sliding_window(
@@ -1037,8 +1041,7 @@ class ModelConfig:
        if self.use_async_output_proc:
            self.use_async_output_proc = False

    def get_hf_config_sliding_window(
            self) -> Union[Optional[int], list[Optional[int]]]:
    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled."""

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1049,7 +1052,7 @@ class ModelConfig:
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
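
In short, interleaving is now keyed on `sliding_window_pattern` rather than a hard-coded model list, and the sliding-window getters return a plain `Optional[int]`. A minimal sketch of that distinction, using throwaway namespace objects instead of vLLM's real config classes (illustrative only):

```python
from types import SimpleNamespace

# Hypothetical text configs, for illustration only.
uniform = SimpleNamespace(model_type="mistral", sliding_window=4096)
hybrid = SimpleNamespace(model_type="gemma2", sliding_window=4096,
                         sliding_window_pattern=2)

def is_interleaved(hf_text_config) -> bool:
    """A model counts as hybrid/interleaved only if it declares a
    sliding_window_pattern; sliding_window alone means a uniform window."""
    return getattr(hf_text_config, "sliding_window_pattern", None) is not None

print(is_interleaved(uniform))  # False -> uniform sliding window of 4096
print(is_interleaved(hybrid))   # True  -> handled as interleaved/hybrid attention
```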


@@ -16,6 +16,7 @@
"""Wrapper around `transformers` models"""
import re
from collections.abc import Iterable
from contextlib import nullcontext
from typing import Literal, Optional, Union
import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
    )


class ConfigOverride:
    """Context manager to temporarily override config attributes."""

    def __init__(self, config: PretrainedConfig, **kwargs):
        self.config = config
        self.kwargs = kwargs
        self.kwargs_original = {}
        self.kwargs_delete = set()

    def __enter__(self):
        """Override config attributes."""
        for key, value in self.kwargs.items():
            if not hasattr(self.config, key):
                self.kwargs_delete.add(key)
            self.kwargs_original[key] = getattr(self.config, key, None)
            setattr(self.config, key, value)
        return self.config

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore original config attributes."""
        for key, value in self.kwargs_original.items():
            if key in self.kwargs_delete:
                delattr(self.config, key)
            else:
                setattr(self.config, key, value)


class TransformersModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # vLLM handles interleaved sliding window attention by creating a new
        # interleaved_sliding_window attribute and deleting the sliding_window
        # attribute. This breaks the constructors in Transformers so we
        # temporarily add the attribute back to construct the model.
        config_override = nullcontext()
        if hasattr(config, "interleaved_sliding_window"):
            config_override = ConfigOverride(
                config, sliding_window=config.interleaved_sliding_window)

        # Use meta device to delay allocating GPU tensors
        with torch.device("meta"):
        with torch.device("meta"), config_override:
            # FIXME(Isotr0py): We need to refactor this part in the future to
            # avoid registering an extra model layer, otherwise we will need a
            # weights mapper to rename weights.
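
A quick sketch of what the override does, using a throwaway namespace instead of a real `PretrainedConfig` (illustrative only; `ConfigOverride` is the class added above):

```python
from types import SimpleNamespace

# vLLM has replaced `sliding_window` with `interleaved_sliding_window`,
# which would break Transformers constructors that expect the original name.
cfg = SimpleNamespace(interleaved_sliding_window=4096)

with ConfigOverride(cfg, sliding_window=cfg.interleaved_sliding_window):
    assert cfg.sliding_window == 4096      # restored for the duration of the block
assert not hasattr(cfg, "sliding_window")  # cleaned up again on exit
```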
@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)
        return {
            i:
            Attention(
        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            sliding_window = None
            if (hasattr(self.config, "interleaved_sliding_window")
                    and hasattr(self.config, "sliding_window_pattern")
                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
                sliding_window = self.config.interleaved_sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=sliding_window,
                prefix=f"{i}.attn")
            for i in range(start, end)
        }
        return attention_instances
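
To make the layer selection concrete, here is a small illustration (not part of the diff) of the `(i + 1) % sliding_window_pattern > 0` rule with the pattern of 2 that the config code assigns to Gemma 2:

```python
sliding_window_pattern = 2  # every second layer uses full attention

for i in range(6):  # hypothetical 6-layer model
    uses_sliding_window = (i + 1) % sliding_window_pattern > 0
    print(i, "sliding_window" if uses_sliding_window else "full_attention")
# 0 sliding_window
# 1 full_attention
# 2 sliding_window
# 3 full_attention
# 4 sliding_window
# 5 full_attention
```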
    def init_buffers(self, module: nn.Module):
        """