Enable hybrid attention models for Transformers backend (#18494)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent c6b636f9fb
commit 4b0da7b60e
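
This change lets models that interleave sliding-window and full-attention layers ("hybrid attention" models such as Gemma 2) run through the Transformers modeling backend. The diffs below touch the interleaved sliding-window documentation, the Transformers-backend tests, `ModelConfig`, and the Transformers model wrapper. As a minimal usage sketch (assuming the `model_impl` engine argument reaches the `LLM` constructor the same way the tests below forward it through `vllm_runner`):

```python
from vllm import LLM, SamplingParams

# Hybrid-attention model served via the Transformers backend rather than
# vLLM's native Gemma 2 implementation.
llm = LLM(
    model="google/gemma-2-2b-it",
    model_impl="transformers",
    max_model_len=8192,
    enforce_eager=True,
)
params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)
```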
```diff
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model.
+- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
```
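
The second bullet above is easier to see with a small stand-alone sketch. This is illustrative only, not vLLM code; the layer-selection rule mirrors the `(i + 1) % sliding_window_pattern > 0` check used in the Transformers backend further down:

```python
from typing import Optional


def per_layer_sliding_windows(
        num_layers: int,
        interleaved_sliding_window: int,
        sliding_window_pattern: int) -> list[Optional[int]]:
    """Value to pass as `per_layer_sliding_window` for each layer.

    Every `sliding_window_pattern`-th layer uses full attention (None);
    all other layers use the interleaved sliding window.
    """
    return [
        interleaved_sliding_window
        if (i + 1) % sliding_window_pattern > 0 else None
        for i in range(num_layers)
    ]


# Gemma-2-style alternation (pattern of 2):
print(per_layer_sliding_windows(6, 4096, 2))
# [4096, None, 4096, None, 4096, None]
```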
```diff
@@ -1,37 +1,50 @@
 # SPDX-License-Identifier: Apache-2.0
 """Test the functionality of the Transformers backend."""
+from typing import Any, Optional, Union

 import pytest

 from vllm.platforms import current_platform

 from ..conftest import HfRunner, VllmRunner
+from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
 from ..utils import multi_gpu_test
 from .utils import check_logprobs_close


 def check_implementation(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
+    runner_ref: type[Union[HfRunner, VllmRunner]],
+    runner_test: type[VllmRunner],
     example_prompts: list[str],
     model: str,
+    kwargs_ref: Optional[dict[str, Any]] = None,
+    kwargs_test: Optional[dict[str, Any]] = None,
     **kwargs,
 ):
+    if kwargs_ref is None:
+        kwargs_ref = {}
+    if kwargs_test is None:
+        kwargs_test = {}
+
     max_tokens = 32
     num_logprobs = 5

-    with vllm_runner(model, **kwargs) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+    args = (example_prompts, max_tokens, num_logprobs)
+
+    with runner_test(model, **kwargs_test, **kwargs) as model_test:
+        outputs_test = model_test.generate_greedy_logprobs(*args)

-    with hf_runner(model) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+    with runner_ref(model, **kwargs_ref) as model_ref:
+        if isinstance(model_ref, VllmRunner):
+            outputs_ref = model_ref.generate_greedy_logprobs(*args)
+        else:
+            outputs_ref = model_ref.generate_greedy_logprobs_limit(*args)

     check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
+        outputs_0_lst=outputs_ref,
+        outputs_1_lst=outputs_test,
+        name_0="ref",
+        name_1="test",
     )


```
```diff
@@ -58,6 +71,18 @@ def test_models(
                          model_impl=model_impl)


+def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
+    prompts, _, _ = prep_prompts(4, (800, 801))
+    kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
+    kwargs_test = {"model_impl": "transformers", **kwargs_ref}
+    check_implementation(vllm_runner,
+                         vllm_runner,
+                         prompts,
+                         model="hmellor/tiny-random-Gemma2ForCausalLM",
+                         kwargs_ref=kwargs_ref,
+                         kwargs_test=kwargs_test)
+
+
 @multi_gpu_test(num_gpus=2)
 def test_distributed(
     hf_runner: type[HfRunner],
@@ -65,8 +90,11 @@ def test_distributed(
     example_prompts,
 ):
     kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
-    check_implementation(hf_runner, vllm_runner, example_prompts,
-                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
+    check_implementation(hf_runner,
+                         vllm_runner,
+                         example_prompts,
+                         "meta-llama/Llama-3.2-1B-Instruct",
+                         kwargs_test=kwargs)


 @pytest.mark.skipif(
```
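
`test_hybrid_attention` drives vLLM against itself: the native Gemma 2 implementation is the reference, the Transformers backend is the implementation under test, and `prep_prompts(4, (800, 801))` presumably produces prompts long enough to exceed the model's sliding window. A hypothetical extension, sketched only to show how further hybrid checkpoints could be folded into the same check (the test name, parametrization, and any extra model IDs are not part of this commit):

```python
import pytest


@pytest.mark.parametrize("model", [
    "hmellor/tiny-random-Gemma2ForCausalLM",
    # additional interleaved sliding-window checkpoints could be listed here
])
def test_hybrid_attention_parametrized(vllm_runner, model):
    prompts, _, _ = prep_prompts(4, (800, 801))
    kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
    kwargs_test = {"model_impl": "transformers", **kwargs_ref}
    check_implementation(vllm_runner,
                         vllm_runner,
                         prompts,
                         model=model,
                         kwargs_ref=kwargs_ref,
                         kwargs_test=kwargs_test)
```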
```diff
@@ -533,13 +533,17 @@ class ModelConfig:
             self.model, hf_token=self.hf_token, revision=self.revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

-        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        has_interleaved_attention = (sliding_window is not None) and (
-            isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in interleaved_attn_models))
+        # Workaround for Gemma 2 which uses interleaved sliding window
+        # attention, but it's not specified in its config. TODO: remove this
+        # when Gemma 2 is fixed in Transformers.
+        if self.hf_text_config.model_type == "gemma2":
+            self.hf_text_config.sliding_window_pattern = 2

-        if (not self.disable_sliding_window and has_interleaved_attention):
+        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
+        sliding_window_pattern = getattr(self.hf_text_config,
+                                         "sliding_window_pattern", None)
+
+        if not (self.disable_sliding_window or sliding_window_pattern is None):
             if (backend :=
                     envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
```
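
The net effect of this block, combined with the behaviour described in the documentation change above (promote `sliding_window` to `interleaved_sliding_window` and delete the original attribute), can be sketched with a stand-in object; `SimpleNamespace` replaces the real `PretrainedConfig` here:

```python
from types import SimpleNamespace

hf_text_config = SimpleNamespace(model_type="gemma2", sliding_window=4096)

# Workaround from the diff: Gemma 2 does not declare its pattern in
# config.json, so vLLM assumes alternating sliding/full attention.
if hf_text_config.model_type == "gemma2":
    hf_text_config.sliding_window_pattern = 2

# Promotion described in the docs: keep the window under a new name and delete
# sliding_window so the rest of vLLM treats the model as full attention.
if getattr(hf_text_config, "sliding_window_pattern", None) is not None:
    hf_text_config.interleaved_sliding_window = hf_text_config.sliding_window
    del hf_text_config.sliding_window

print(vars(hf_text_config))
# {'model_type': 'gemma2', 'sliding_window_pattern': 2,
#  'interleaved_sliding_window': 4096}
```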
```diff
@@ -1037,8 +1041,7 @@ class ModelConfig:
        if self.use_async_output_proc:
            self.use_async_output_proc = False

-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled."""

         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1049,7 +1052,7 @@ class ModelConfig:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)

-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
+    def get_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
         # If user disables sliding window, return None.
```
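
Because interleaved models now have `sliding_window` deleted from `hf_text_config` (see the documentation change above), these getters only ever see a single integer or nothing, which is why the list-valued return types could be dropped. A behavioural sketch with hypothetical stand-in configs:

```python
from types import SimpleNamespace
from typing import Optional

# One uniform sliding-window model, one interleaved model after vLLM's
# preprocessing (sliding_window deleted, interleaved_* attributes added).
uniform = SimpleNamespace(sliding_window=4096)
interleaved = SimpleNamespace(interleaved_sliding_window=4096,
                              sliding_window_pattern=2)


def get_hf_config_sliding_window(cfg) -> Optional[int]:
    return getattr(cfg, "sliding_window", None)


print(get_hf_config_sliding_window(uniform))      # 4096
print(get_hf_config_sliding_window(interleaved))  # None -> full attention
```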
```diff
@@ -16,6 +16,7 @@
 """Wrapper around `transformers` models"""
 import re
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Literal, Optional, Union

 import torch
```
```diff
@@ -110,6 +111,33 @@ def replace_linear_class(
     )


+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
 class TransformersModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
```
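
A usage sketch for the `ConfigOverride` context manager added above, assuming a base `PretrainedConfig` that does not already define `sliding_window`; the attribute is removed again on exit, which is exactly how the constructor below relies on it:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(interleaved_sliding_window=4096)

with ConfigOverride(config, sliding_window=config.interleaved_sliding_window):
    # Inside the block, Transformers constructors see a plain sliding_window.
    assert config.sliding_window == 4096

# The temporary attribute was not present originally, so it is deleted on exit.
assert not hasattr(config, "sliding_window")
```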
```diff
@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()

+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
         # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
             # FIXME(Isotr0py): We need to refactor this part in the future to
             # avoid registering an extra model layer, otherwise we will need a
             # weights mapper to rename weights.
```
```diff
@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances

     def init_buffers(self, module: nn.Module):
         """
```
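
To make the layer layout concrete: the loop above gives each pipeline-parallel rank a dict of `Attention` instances keyed by global layer index, where every layer except each `sliding_window_pattern`-th one receives the window. A self-contained sketch follows; the even split is a simplification standing in for `get_pp_indices`, which also handles uneven partitions:

```python
def pp_indices(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    # Simplified even split; vLLM's get_pp_indices is the real implementation.
    per_rank = num_layers // pp_size
    return pp_rank * per_rank, (pp_rank + 1) * per_rank


num_hidden_layers, pattern, window = 8, 2, 4096
for pp_rank in range(2):
    start, end = pp_indices(num_hidden_layers, pp_rank, pp_size=2)
    layout = {i: (window if (i + 1) % pattern > 0 else None)
              for i in range(start, end)}
    print(pp_rank, layout)
# 0 {0: 4096, 1: None, 2: 4096, 3: None}
# 1 {4: 4096, 5: None, 6: 4096, 7: None}
```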