Add explicit pooling classes for the Transformers backend (#25322)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Harry Mellor 2025-09-30 23:07:06 +01:00 committed by yewentao256
parent 8328d39d40
commit b3e1846da6
7 changed files with 295 additions and 137 deletions
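In practice, the new classes mean a pooling model served through the Transformers backend resolves to TransformersEmbeddingModel or TransformersForSequenceClassification rather than the old catch-all TransformersModel. A minimal usage sketch, assuming the offline LLM entrypoint and using the embedding model referenced in the tests below (the kwargs are illustrative, not taken from this commit):

from vllm import LLM

# Force the Transformers backend; the pooling runner with the "embed"
# convert option should resolve to TransformersEmbeddingModel.
llm = LLM(model="BAAI/bge-base-en-v1.5",
          runner="pooling",
          model_impl="transformers")
outputs = llm.embed(["vLLM now has explicit pooling classes."])
print(len(outputs[0].outputs.embedding))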

View File

@@ -657,7 +657,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
"TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
}

View File

@@ -9,9 +9,16 @@ from vllm.platforms import current_platform
from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
from .registry import HF_EXAMPLE_MODELS
from .utils import check_embeddings_close, check_logprobs_close
def get_model(arch: str) -> str:
model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
model_info.check_transformers_version(on_fail="skip")
return model_info.default
def check_implementation(
runner_ref: type[Union[HfRunner, VllmRunner]],
runner_test: type[VllmRunner],
@@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
@pytest.mark.parametrize(
"model",
[
# Encoder model
"BAAI/bge-base-en-v1.5",
])
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
pytest.skip("Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
"arch",
["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
model = get_model(arch)
with vllm_runner(model, max_model_len=512,
model_impl="transformers") as vllm_model:
vllm_kwargs = dict(
max_model_len=None,
model_impl="transformers",
compilation_config=dict(cudagraph_capture_sizes=[8]),
)
hf_kwargs = dict()
if arch == "TransformersEmbeddingModel":
hf_kwargs["is_sentence_transformer"] = True
elif arch == "TransformersForSequenceClassification":
from transformers import AutoModelForSequenceClassification
hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
# The example_prompts end with "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids different between hf_model and vllm_model.
# So we need to strip the input texts to avoid the test failing.
example_prompts = [str(s).strip() for s in example_prompts]
with (vllm_runner(model, **vllm_kwargs) as
vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()
vllm_outputs = vllm_model.embed(example_prompts)
with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
if arch == "TransformersEmbeddingModel":
vllm_outputs = vllm_model.embed(example_prompts)
hf_outputs = hf_model.encode(example_prompts)
elif arch == "TransformersForSequenceClassification":
vllm_outputs = vllm_model.classify(example_prompts)
hf_outputs = hf_model.classify(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["float"])
def test_classify(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
import torch
from transformers import AutoModelForSequenceClassification
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
model_impl="transformers") as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
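For reference, a standalone sketch of the kind of check that check_embeddings_close performs for the new test_pooling above (assumed semantics, not the vLLM helper itself: each hf/vllm output pair should have cosine similarity within tol of 1):

import torch
import torch.nn.functional as F

def embeddings_close(embs_a, embs_b, tol=1e-2):
    # Compare per-prompt outputs by cosine similarity (assumed semantics).
    for a, b in zip(embs_a, embs_b):
        sim = F.cosine_similarity(torch.tensor(a), torch.tensor(b), dim=0)
        if not torch.isclose(sim, torch.tensor(1.0), atol=tol):
            return False
    return True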

View File

@@ -19,6 +19,7 @@ import vllm.envs as envs
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
import vllm.model_executor.models as me_models
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.scheduler import RunnerType
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.v1.sample.logits_processor import LogitsProcessor
else:
@@ -52,13 +52,12 @@ else:
"vllm.model_executor.models")
LoadConfig = Any
ParallelConfig = Any
RunnerType = Any
QuantizationMethods = Any
LogitsProcessor = Any
logger = init_logger(__name__)
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -668,8 +667,28 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if getattr(self, "runner_type", self.runner) == "pooling":
return "TransformersModel"
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
if defaults := try_match_architecture_defaults(self.architectures[0]):
_, (runner, convert) = defaults
# Overwrite with user-specified values
if self.runner != "auto":
runner = self.runner
if self.convert not in {"auto", "none"}:
convert = self.convert
# Fall back to default values if still not set
if runner is None:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend pooling classes
if runner == "pooling":
if convert == "embed":
return "TransformersEmbeddingModel"
if convert == "classify":
return "TransformersForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' differs from 'hf_config', it is
# probably a composite config, i.e. multimodal
@@ -678,7 +697,9 @@ class ModelConfig:
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
return self.architecture == self._get_transformers_backend_cls()
used_cls = self._model_info.architecture
transformers_backend_cls = self._get_transformers_backend_cls()
return used_cls == transformers_backend_cls
@property
def registry(self):
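To make the new resolution logic concrete, a self-contained sketch of how _get_transformers_backend_cls maps the resolved (runner, convert) pair and the config shape to a backend class (the helper name and arguments are illustrative, not vLLM API):

def resolve_transformers_backend_cls(runner: str, convert: str,
                                     is_composite_config: bool) -> str:
    # Mirrors the logic added above: pooling runners get the new explicit
    # pooling classes; everything else falls through to the generate classes.
    if runner == "pooling":
        if convert == "embed":
            return "TransformersEmbeddingModel"
        if convert == "classify":
            return "TransformersForSequenceClassification"
    if is_composite_config:  # hf_config != hf_text_config, e.g. multimodal
        return "TransformersForMultimodalLM"
    return "TransformersForCausalLM"

# e.g. a BERT-style embedding architecture whose defaults resolve to
# ("pooling", "embed"):
assert (resolve_transformers_backend_cls("pooling", "embed", False)
        == "TransformersEmbeddingModel")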

View File

@@ -4,6 +4,7 @@
import ast
import inspect
import textwrap
from collections.abc import Iterable
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
"""
A helper function that retrieves an attribute from an object, trying each of
the given names in order. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances.
"""
for name in names:
if hasattr(object, name):
return getattr(object, name)
return default
def contains_object_print(text: str) -> bool:
"""
Check if the text looks like a printed Python object, e.g.
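A short usage sketch for the new getattr_iter helper (the config object here is hypothetical; the Transformers backend further down uses the helper the same way to pick an embedding dimension):

from types import SimpleNamespace

from vllm.config.utils import getattr_iter

# Some configs expose `embedding_size`, others only `hidden_size`; take the
# first attribute that exists, otherwise fall back to the default.
config = SimpleNamespace(hidden_size=768)
embedding_dim = getattr_iter(config, ("embedding_size", "hidden_size"), None)
assert embedding_dim == 768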

View File

@@ -307,9 +307,10 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": ("transformers", "TransformersModel"),
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
}
# yapf: enable
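Each registry value is a (module, class) pair; a hedged sketch of how such an entry resolves to a class (the package prefix is assumed from the imports elsewhere in this commit):

import importlib

module_name, class_name = ("transformers_pooling", "TransformersEmbeddingModel")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
cls = getattr(module, class_name)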

View File

@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
@@ -486,10 +487,13 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
names = ("embedding_size", "hidden_size")
embedding_dim = getattr_iter(self.text_config, names, None)
assert embedding_dim is not None
self.model.set_input_embeddings(
VocabParallelEmbedding(
self.text_config.vocab_size,
self.text_config.hidden_size,
embedding_dim=embedding_dim,
org_num_embeddings=self.text_config.vocab_size,
quant_config=self.quant_config,
))
@@ -645,7 +649,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
attn_type=attn_type)
return attention_instances
def init_parameters(self, module: nn.Module):
def init_parameters(self,
module: nn.Module,
dtype: Optional[torch.dtype] = None):
"""
If a `parameter` is on the `meta` device, then its parent
`module` is the original module created by:
@@ -659,11 +665,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data,
dtype=self.model_config.dtype,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device))
setattr(module, name, new_param)
for child in module.children():
self.init_parameters(child)
self.init_parameters(child, dtype)
def forward(
self,
@@ -712,73 +718,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersModel(TransformersBase):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Handle BERT-like models
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Pooling adapters will be adjacent to `model`
"model.pooler": "pooler",
"model.score": "score",
# Classifier adapter's classifier layer is renamed to score
"model.classifier": "score",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# After creating a pooling model, `pooler` will be duplicated.
# The one inside `model` comes from the Transformers modelling code.
# The one after `model` is an adapter from vLLM.
# We want to use the adapter so we nullify the original pooler.
if getattr(self.model, "pooler", None) is not None:
self.skip_prefixes.append("pooler.")
self.model.pooler = torch.nn.Identity()
# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
return super().create_attention_instances(attn_type)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
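The dtype argument added to init_parameters above lets callers materialize meta-device parameters in a dtype other than the model dtype (the new classifier head later passes head_dtype). A standalone sketch of the same pattern outside vLLM:

import torch
import torch.nn as nn

def materialize_meta_params(module: nn.Module, dtype: torch.dtype,
                            device: str = "cpu") -> None:
    # Same idea as TransformersBase.init_parameters: replace any parameter
    # still on the meta device with an empty tensor of the requested dtype.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            new_param = nn.Parameter(
                torch.empty_like(param.data, dtype=dtype, device=device))
            setattr(module, name, new_param)
    for child in module.children():
        materialize_meta_params(child, dtype, device)

with torch.device("meta"):
    head = nn.Linear(768, 2)  # created without allocating real storage
materialize_meta_params(head, torch.float32)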

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models for pooling tasks."""
from typing import Optional, Union
import torch
from transformers import AutoModelForSequenceClassification
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.sequence import IntermediateTensors
from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile
from .utils import WeightsMapper
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend([
"model.lm_head.", "model.predictions.", "model.qa_outputs.",
"model.embeddings_project.", "model.discriminator_predictions."
])
# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
# roberta-like models apply an extra padding offset to positions.
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
# we should find a better way to handle this.
self.is_roberta = "roberta" in self.text_config.model_type
self.padding_idx = self.text_config.pad_token_id
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
return super().create_attention_instances(attn_type)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.is_roberta:
# RoBERTa-specific positions padding
positions += self.padding_idx + 1
return super().forward(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(TransformersPoolingBase):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(TransformersPoolingBase):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# Certain information about the model and classifier can only be
# inferred from the `ForSequenceClassification` class. Therefore, we
# instantiate it on the "meta" device to avoid allocating GPU memory.
with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# When used for sequence classification, some models have their
# pooling layers removed. Make sure this is reflected in vLLM.
for module in seq_cls_model.modules():
if hasattr(module, "pooler") and module.pooler is None:
self.model.pooler = None
break
if self.model.pooler is not None:
raise ValueError(
"Sequence classification models with pooling layers are not "
"supported yet in the Transformers backend.")
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
self.classifier = seq_cls_model.classifier
self.init_parameters(self.classifier,
dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__):
"""CLSPool has already been applied in `pooling`.
Add dim to match expected input shape of `classifier.forward`."""
def forward(self, *args, **kwargs):
if len(args) > 0:
args = (args[0].unsqueeze(1), *args[1:])
return super().forward(*args, **kwargs)
self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
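Finally, a minimal serving sketch for the new classification path, assuming the offline LLM entrypoint and the classification model used in the tests (the kwargs are illustrative):

from vllm import LLM

# Should resolve to TransformersForSequenceClassification via the pooling
# runner and the "classify" convert option.
llm = LLM(model="papluca/xlm-roberta-base-language-detection",
          runner="pooling",
          model_impl="transformers")
outputs = llm.classify(["vLLM is a high-throughput inference engine."])
print(outputs[0].outputs.probs)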