mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 01:35:47 +08:00
Add explicit pooling classes for the Transformers backend (#25322)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
8328d39d40
commit
b3e1846da6
@ -657,7 +657,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
|
||||
"TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501
|
||||
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
|
||||
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||
}
|
||||
|
||||
@ -9,9 +9,16 @@ from vllm.platforms import current_platform
|
||||
|
||||
from ..conftest import HfRunner, VllmRunner
|
||||
from ..utils import multi_gpu_test, prep_prompts
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
from .utils import check_embeddings_close, check_logprobs_close
|
||||
|
||||
|
||||
def get_model(arch: str) -> str:
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
return model_info.default
|
||||
|
||||
|
||||
def check_implementation(
|
||||
runner_ref: type[Union[HfRunner, VllmRunner]],
|
||||
runner_test: type[VllmRunner],
|
||||
@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# Encoder model
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
])
|
||||
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
installed = Version(transformers.__version__)
|
||||
required = Version("4.57.0.dev0")
|
||||
if installed < required:
|
||||
pytest.skip("Encoder models with the Transformers backend require "
|
||||
f"transformers>={required}, but got {installed}")
|
||||
"arch",
|
||||
["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
|
||||
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
|
||||
model = get_model(arch)
|
||||
|
||||
with vllm_runner(model, max_model_len=512,
|
||||
model_impl="transformers") as vllm_model:
|
||||
vllm_kwargs = dict(
|
||||
max_model_len=None,
|
||||
model_impl="transformers",
|
||||
compilation_config=dict(cudagraph_capture_sizes=[8]),
|
||||
)
|
||||
|
||||
hf_kwargs = dict()
|
||||
if arch == "TransformersEmbeddingModel":
|
||||
hf_kwargs["is_sentence_transformer"] = True
|
||||
elif arch == "TransformersForSequenceClassification":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
|
||||
|
||||
# The example_prompts has ending "\n", for example:
|
||||
# "Write a short story about a robot that dreams for the first time.\n"
|
||||
# sentence_transformers will strip the input texts, see:
|
||||
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
|
||||
# This makes the input_ids different between hf_model and vllm_model.
|
||||
# So we need to strip the input texts to avoid test failing.
|
||||
example_prompts = [str(s).strip() for s in example_prompts]
|
||||
|
||||
with (vllm_runner(model, **vllm_kwargs) as
|
||||
vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
if arch == "TransformersEmbeddingModel":
|
||||
vllm_outputs = vllm_model.embed(example_prompts)
|
||||
|
||||
with hf_runner(model, is_sentence_transformer=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
elif arch == "TransformersForSequenceClassification":
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
hf_outputs = hf_model.classify(example_prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["jason9693/Qwen2.5-1.5B-apeach"],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_classify(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
with vllm_runner(model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
model_impl="transformers") as vllm_model:
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
||||
hf_outputs = hf_model.classify(example_prompts)
|
||||
|
||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||
hf_output = torch.tensor(hf_output)
|
||||
vllm_output = torch.tensor(vllm_output)
|
||||
|
||||
assert torch.allclose(hf_output, vllm_output,
|
||||
1e-3 if dtype == "float" else 1e-2)
|
||||
|
||||
@ -19,6 +19,7 @@ import vllm.envs as envs
|
||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||
MultiModalConfig)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import assert_hashable, config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -40,7 +41,6 @@ if TYPE_CHECKING:
|
||||
import vllm.model_executor.models as me_models
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.scheduler import RunnerType
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
else:
|
||||
@ -52,13 +52,12 @@ else:
|
||||
"vllm.model_executor.models")
|
||||
LoadConfig = Any
|
||||
ParallelConfig = Any
|
||||
RunnerType = Any
|
||||
QuantizationMethods = Any
|
||||
LogitsProcessor = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
|
||||
RunnerOption = Literal["auto", RunnerType]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||
@ -668,8 +667,28 @@ class ModelConfig:
|
||||
def _get_transformers_backend_cls(self) -> str:
|
||||
"""Determine which Transformers backend class will be used if
|
||||
`model_impl` is set to `transformers` or `auto`."""
|
||||
if getattr(self, "runner_type", self.runner) == "pooling":
|
||||
return "TransformersModel"
|
||||
# Check if the architecture we're wrapping has defaults
|
||||
runner = None
|
||||
convert = None
|
||||
if defaults := try_match_architecture_defaults(self.architectures[0]):
|
||||
_, (runner, convert) = defaults
|
||||
# Overwrite with user-specified values
|
||||
if self.runner != "auto":
|
||||
runner = self.runner
|
||||
if self.convert not in {"auto", "none"}:
|
||||
convert = self.convert
|
||||
# Fall back to default values if still not set
|
||||
if runner is None:
|
||||
runner = "generate"
|
||||
if convert in {None, "none"}:
|
||||
convert = "embed"
|
||||
# Resolve Transformers backend pooling classes
|
||||
if runner == "pooling":
|
||||
if convert == "embed":
|
||||
return "TransformersEmbeddingModel"
|
||||
if convert == "classify":
|
||||
return "TransformersForSequenceClassification"
|
||||
# Resolve Transformers backend generate classes
|
||||
if self.hf_config != self.hf_text_config:
|
||||
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
|
||||
# probably a composite config, i.e. multimodal
|
||||
@ -678,7 +697,9 @@ class ModelConfig:
|
||||
|
||||
def using_transformers_backend(self) -> bool:
|
||||
"""Check if the model is using the Transformers backend class."""
|
||||
return self.architecture == self._get_transformers_backend_cls()
|
||||
used_cls = self._model_info.architecture
|
||||
transformers_backend_cls = self._get_transformers_backend_cls()
|
||||
return used_cls == transformers_backend_cls
|
||||
|
||||
@property
|
||||
def registry(self):
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
|
||||
@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
"""
|
||||
for name in names:
|
||||
if hasattr(object, name):
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
def contains_object_print(text: str) -> bool:
|
||||
"""
|
||||
Check if the text looks like a printed Python object, e.g.
|
||||
|
||||
@ -307,7 +307,8 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersModel": ("transformers", "TransformersModel"),
|
||||
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
|
||||
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
|
||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, VllmConfig)
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
from vllm.logger import init_logger
|
||||
@ -486,10 +487,13 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
|
||||
# Input embeddings
|
||||
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
|
||||
names = ("embedding_size", "hidden_size")
|
||||
embedding_dim = getattr_iter(self.text_config, names, None)
|
||||
assert embedding_dim is not None
|
||||
self.model.set_input_embeddings(
|
||||
VocabParallelEmbedding(
|
||||
self.text_config.vocab_size,
|
||||
self.text_config.hidden_size,
|
||||
embedding_dim=embedding_dim,
|
||||
org_num_embeddings=self.text_config.vocab_size,
|
||||
quant_config=self.quant_config,
|
||||
))
|
||||
@ -645,7 +649,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
attn_type=attn_type)
|
||||
return attention_instances
|
||||
|
||||
def init_parameters(self, module: nn.Module):
|
||||
def init_parameters(self,
|
||||
module: nn.Module,
|
||||
dtype: Optional[torch.dtype] = None):
|
||||
"""
|
||||
If a `parameter` is on the `meta` device, then its parent
|
||||
`module` is the original module created by:
|
||||
@ -659,11 +665,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
if param.device == torch.device("meta"):
|
||||
new_param = nn.Parameter(
|
||||
torch.empty_like(param.data,
|
||||
dtype=self.model_config.dtype,
|
||||
dtype=dtype or self.model_config.dtype,
|
||||
device=self.device_config.device))
|
||||
setattr(module, name, new_param)
|
||||
for child in module.children():
|
||||
self.init_parameters(child)
|
||||
self.init_parameters(child, dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -712,73 +718,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersModel(TransformersBase):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# Handle BERT-like models
|
||||
"bert": "model",
|
||||
# Add `model.` prefix for base model checkpoints
|
||||
"": "model.",
|
||||
# Remove `model.` prefix if it was already there
|
||||
"model.model.": "model.",
|
||||
# Pooling adapters will be adjacent to `model`
|
||||
"model.pooler": "pooler",
|
||||
"model.score": "score",
|
||||
# Classifier adapter's classifier layer is renamed to score
|
||||
"model.classifier": "score",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
# Replace legacy suffixes used for norms
|
||||
".gamma": ".weight",
|
||||
".beta": ".bias",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
# After creating a pooling model, `pooler` will be duplicated.
|
||||
# The one inside `model` comes from the Transformers modelling code.
|
||||
# The one after `model` is an adapter from vLLM.
|
||||
# We want to use the adapter so we nullify the original pooler.
|
||||
if getattr(self.model, "pooler", None) is not None:
|
||||
self.skip_prefixes.append("pooler.")
|
||||
self.model.pooler = torch.nn.Identity()
|
||||
|
||||
# Some encoder models have the position_ids buffer in the checkpoint.
|
||||
# vLLM will always pass position_ids as an argument, so we skip loading
|
||||
# the buffer if it exists
|
||||
self.skip_substrs.append("position_ids")
|
||||
|
||||
# Some encoder models have the bias of the final classifier layer
|
||||
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
||||
# it if it exists
|
||||
self.skip_substrs.append("score.bias")
|
||||
|
||||
def create_attention_instances(
|
||||
self, attn_type: AttentionType = AttentionType.DECODER):
|
||||
# TODO(hmellor): Better way to detect encoder models
|
||||
# In encoder models, the attention layers will have `is_causal=False`
|
||||
is_encoder = lambda m: not getattr(m, "is_causal", True)
|
||||
# vLLM does not support encoder-decoder models, so if any encoder layer
|
||||
# is found, we assume the whole model is an encoder model
|
||||
if any(is_encoder(m) for m in self.model.modules()):
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
# Check minimum transformers version for encoder models support
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
installed = Version(transformers.__version__)
|
||||
required = Version("4.57.0.dev0")
|
||||
if installed < required:
|
||||
raise ValueError(
|
||||
"Encoder models with the Transformers backend require "
|
||||
f"transformers>={required}, but got {installed}")
|
||||
|
||||
return super().create_attention_instances(attn_type)
|
||||
|
||||
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersForCausalLM(TransformersBase):
|
||||
|
||||
|
||||
200
vllm/model_executor/models/transformers_pooling.py
Normal file
200
vllm/model_executor/models/transformers_pooling.py
Normal file
@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2024 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Wrapper around `transformers` models for pooling tasks."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
||||
DispatchPooler, Pooler)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces_base import VllmModelForPooling
|
||||
from .transformers import TransformersBase, can_enable_torch_compile
|
||||
from .utils import WeightsMapper
|
||||
|
||||
|
||||
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
# These are applied in order, so the order matters!
|
||||
orig_to_new_prefix={
|
||||
# Handle BERT-like models
|
||||
"roberta": "model",
|
||||
"bert": "model",
|
||||
# Add `model.` prefix for base model checkpoints
|
||||
"": "model.",
|
||||
# Remove `model.` prefix if it was already there
|
||||
"model.model.": "model.",
|
||||
# Classifier/scoring heads will be adjacent to `model`
|
||||
"model.score": "classifier",
|
||||
"model.classifier": "classifier",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
# Replace legacy suffixes used for norms
|
||||
".gamma": ".weight",
|
||||
".beta": ".bias",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
# Skip unsupported/unwanted output embeddings layers
|
||||
self.skip_prefixes.extend([
|
||||
"model.lm_head.", "model.predictions.", "model.qa_outputs.",
|
||||
"model.embeddings_project.", "model.discriminator_predictions."
|
||||
])
|
||||
|
||||
# Some encoder models have the position_ids buffer in the checkpoint.
|
||||
# vLLM will always pass position_ids as an argument, so we skip loading
|
||||
# the buffer if it exists
|
||||
self.skip_substrs.append("position_ids")
|
||||
|
||||
# Some encoder models have the bias of the final classifier layer
|
||||
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
||||
# it if it exists
|
||||
self.skip_substrs.append("score.bias")
|
||||
|
||||
# roberta-like models an extra padding in positions.
|
||||
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
|
||||
# we should find a better way to handle this.
|
||||
self.is_roberta = "roberta" in self.text_config.model_type
|
||||
self.padding_idx = self.text_config.pad_token_id
|
||||
|
||||
def create_attention_instances(
|
||||
self, attn_type: AttentionType = AttentionType.DECODER):
|
||||
# TODO(hmellor): Better way to detect encoder models
|
||||
# In encoder models, the attention layers will have `is_causal=False`
|
||||
is_encoder = lambda m: not getattr(m, "is_causal", True)
|
||||
# vLLM does not support encoder-decoder models, so if any encoder layer
|
||||
# is found, we assume the whole model is an encoder model
|
||||
if any(is_encoder(m) for m in self.model.modules()):
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
# Check minimum transformers version for encoder models support
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
installed = Version(transformers.__version__)
|
||||
required = Version("4.57.0.dev0")
|
||||
if installed < required:
|
||||
raise ValueError(
|
||||
"Encoder models with the Transformers backend require "
|
||||
f"transformers>={required}, but got {installed}")
|
||||
|
||||
return super().create_attention_instances(attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.is_roberta:
|
||||
# RoBERTa-specific positions padding
|
||||
positions += self.padding_idx + 1
|
||||
return super().forward(input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersEmbeddingModel(TransformersPoolingBase):
|
||||
default_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
})
|
||||
|
||||
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersForSequenceClassification(TransformersPoolingBase):
|
||||
default_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
# Certain information about the the model and classifier can only be
|
||||
# inferred from the `ForSequenceClassification` class. Therefore, we
|
||||
# instantiate it on the "meta" device to avoid allocating GPU memory.
|
||||
with torch.device("meta"):
|
||||
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
||||
self.config,
|
||||
torch_dtype=self.model_config.dtype,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
# When used for sequence classification, some models have their
|
||||
# pooling layers removed. Make sure this is reflected in vLLM.
|
||||
for module in seq_cls_model.modules():
|
||||
if hasattr(module, "pooler") and module.pooler is None:
|
||||
self.model.pooler = None
|
||||
break
|
||||
if self.model.pooler is not None:
|
||||
raise ValueError(
|
||||
"Sequence classification models with pooling layers are not "
|
||||
"supported yet in the Transformers backend.")
|
||||
|
||||
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
|
||||
self.classifier = seq_cls_model.classifier
|
||||
self.init_parameters(self.classifier,
|
||||
dtype=self.model_config.head_dtype)
|
||||
|
||||
class ClassifierWithReshape(self.classifier.__class__):
|
||||
"""CLSPool has already been applied in `pooling`.
|
||||
Add dim to match expected input shape of `classifier.forward`."""
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
args = (args[0].unsqueeze(1), *args[1:])
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
self.classifier.__class__ = ClassifierWithReshape
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
Loading…
x
Reference in New Issue
Block a user