Encoder model support for the Transformers backend (#25174)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-09-19 19:15:22 +01:00 committed by GitHub
parent d90e212a3a
commit 12aed7e453
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 111 additions and 24 deletions

View File

@ -17,9 +17,24 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
To check if the modeling backend is Transformers, you can simply do this:
Currently, the Transformers backend works for the following:
- Modalities: embedding models, language models and vision-language models*
- Architectures: encoder-only, decoder-only
- Attention types: full attention and/or sliding attention
_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM:
- All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes:
- Pipeline parallel
- Tensor parallel
Checking if the modeling backend is Transformers is as simple as:
```python
from vllm import LLM
@ -27,16 +42,12 @@ llm = LLM(model=...) # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```
If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
If the printed type starts with `Transformers...` then it's using the Transformers model implementation!
!!! tip
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md).
!!! note
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
!!! note
In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
#### Custom models
@ -66,10 +77,11 @@ This section details the necessary modifications to make to a Transformers compa
To make your model compatible with the Transformers backend, it needs:
1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
3. `MyModel` must contain `_supports_attention_backend = True`.
<details>
<details class="code">
<summary>modeling_my_model.py</summary>
```python
@ -78,6 +90,7 @@ from transformers import PreTrainedModel
from torch import nn
class MyAttention(nn.Module):
is_causal = False # Only do this for encoder-only models
def forward(self, hidden_states, **kwargs):
...
@ -101,13 +114,13 @@ Here is what happens in the background when this model is loaded:
1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
3. `MyModel` is loaded into one of the Transformers backend classes in <gh-file:vllm/model_executor/models/transformers.py> which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
That's it!
For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class:
<details>
<details class="code">
<summary>configuration_my_model.py</summary>
```python

View File

@ -9,7 +9,7 @@ from vllm.platforms import current_platform
from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
from .utils import check_logprobs_close
from .utils import check_embeddings_close, check_logprobs_close
def check_implementation(
@ -165,6 +165,40 @@ def test_embed_loading(vllm_runner, model):
assert model_config.using_transformers_backend()
@pytest.mark.parametrize(
"model",
[
# Encoder model
"BAAI/bge-base-en-v1.5",
])
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
pytest.skip("Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
with vllm_runner(model, max_model_len=512,
model_impl="transformers") as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()
vllm_outputs = vllm_model.embed(example_prompts)
with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],

View File

@ -23,14 +23,14 @@ class AttentionType:
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
"""Decoder attention between previous layer Q/K/V."""
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
"""Encoder attention between previous layer Q/K/V."""
ENCODER_DECODER = "encoder_decoder"
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
class AttentionBackend(ABC):

View File

@ -27,7 +27,7 @@ from transformers import (AutoModel, BatchFeature, PretrainedConfig,
PreTrainedModel)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
@ -452,8 +452,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# To be updated in child classes for use in `load_weights`
self.skip_prefixes: Optional[list[str]] = None
# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
self.skip_substrs: list[str] = []
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
@ -596,7 +597,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
_tensor_parallel(self.model)
def create_attention_instances(self) -> dict[int, Attention]:
def create_attention_instances(
self,
attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
@ -625,7 +629,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
cache_config=self.cache_config,
quant_config=self.quant_config,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn")
prefix=f"{i}.attn",
attn_type=attn_type)
return attention_instances
def init_parameters(self, module: nn.Module):
@ -685,7 +690,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
skip_substrs=self.skip_substrs,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@ -700,6 +709,37 @@ class TransformersModel(TransformersBase):
"model.score": "score",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Some encoder models have the position_ids buffer in the checkpoint
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
return super().create_attention_instances(attn_type)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
@ -710,7 +750,7 @@ class TransformersForCausalLM(TransformersBase):
# Tell `TransformersBase.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes = ["lm_head."]
self.skip_prefixes.append("lm_head.")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size