Add explicit pooling classes for the Transformers backend (#25322)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Harry Mellor on 2025-09-30 23:07:06 +01:00; committed by yewentao256
parent 8328d39d40
commit b3e1846da6
7 changed files with 295 additions and 137 deletions

View File

@@ -657,7 +657,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
+    "TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"),  # noqa: E501
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
 }

View File

@@ -9,9 +9,16 @@ from vllm.platforms import current_platform
 from ..conftest import HfRunner, VllmRunner
 from ..utils import multi_gpu_test, prep_prompts
+from .registry import HF_EXAMPLE_MODELS
 from .utils import check_embeddings_close, check_logprobs_close
 
 
+def get_model(arch: str) -> str:
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
+    model_info.check_transformers_version(on_fail="skip")
+    return model_info.default
+
+
 def check_implementation(
     runner_ref: type[Union[HfRunner, VllmRunner]],
     runner_test: type[VllmRunner],
@@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
 
 @pytest.mark.parametrize(
-    "model",
-    [
-        # Encoder model
-        "BAAI/bge-base-en-v1.5",
-    ])
-def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
-    import transformers
-    from packaging.version import Version
-    installed = Version(transformers.__version__)
-    required = Version("4.57.0.dev0")
-    if installed < required:
-        pytest.skip("Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
-
-    with vllm_runner(model, max_model_len=512,
-                     model_impl="transformers") as vllm_model:
+    "arch",
+    ["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
+def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
+    model = get_model(arch)
+
+    vllm_kwargs = dict(
+        max_model_len=None,
+        model_impl="transformers",
+        compilation_config=dict(cudagraph_capture_sizes=[8]),
+    )
+
+    hf_kwargs = dict()
+    if arch == "TransformersEmbeddingModel":
+        hf_kwargs["is_sentence_transformer"] = True
+    elif arch == "TransformersForSequenceClassification":
+        from transformers import AutoModelForSequenceClassification
+        hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
+
+    # The example_prompts has ending "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers will strip the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids different between hf_model and vllm_model.
+    # So we need to strip the input texts to avoid test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    with (vllm_runner(model, **vllm_kwargs) as
+          vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
         model_config = vllm_model.llm.llm_engine.model_config
         assert model_config.using_transformers_backend()
 
-        vllm_outputs = vllm_model.embed(example_prompts)
-
-    with hf_runner(model, is_sentence_transformer=True) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts)
+        if arch == "TransformersEmbeddingModel":
+            vllm_outputs = vllm_model.embed(example_prompts)
+            hf_outputs = hf_model.encode(example_prompts)
+        elif arch == "TransformersForSequenceClassification":
+            vllm_outputs = vllm_model.classify(example_prompts)
+            hf_outputs = hf_model.classify(example_prompts)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
         embeddings_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
-        tol=1e-2,
     )
-
-
-@pytest.mark.parametrize(
-    "model",
-    ["jason9693/Qwen2.5-1.5B-apeach"],
-)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_classify(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-) -> None:
-    import torch
-    from transformers import AutoModelForSequenceClassification
-
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     model_impl="transformers") as vllm_model:
-        model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.using_transformers_backend()
-
-        vllm_outputs = vllm_model.classify(example_prompts)
-
-    with hf_runner(model,
-                   dtype=dtype,
-                   auto_cls=AutoModelForSequenceClassification) as hf_model:
-        hf_outputs = hf_model.classify(example_prompts)
-
-    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
-        hf_output = torch.tensor(hf_output)
-        vllm_output = torch.tensor(vllm_output)
-        assert torch.allclose(hf_output, vllm_output,
-                              1e-3 if dtype == "float" else 1e-2)
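For reference, the same embedding check can be reproduced outside pytest. A minimal sketch, assuming vLLM's public `LLM.embed` API, the `model_impl` argument, and the `sentence_transformers` package:

import numpy as np
from sentence_transformers import SentenceTransformer
from vllm import LLM

prompts = ["vLLM now has explicit pooling classes."]

# Hugging Face reference embeddings
st_embeds = SentenceTransformer("BAAI/bge-base-en-v1.5").encode(prompts)

# vLLM embeddings through the Transformers backend
llm = LLM(model="BAAI/bge-base-en-v1.5", model_impl="transformers")
vllm_embeds = np.array([o.outputs.embedding for o in llm.embed(prompts)])

# Cosine similarity should be ~1.0 if the backend matches the reference
cos = (st_embeds * vllm_embeds).sum(-1) / (
    np.linalg.norm(st_embeds, axis=-1) * np.linalg.norm(vllm_embeds, axis=-1))
print(cos)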

View File

@@ -19,6 +19,7 @@ import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
+from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
     import vllm.model_executor.models as me_models
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
-    from vllm.config.scheduler import RunnerType
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.v1.sample.logits_processor import LogitsProcessor
 else:
@@ -52,13 +52,12 @@ else:
         "vllm.model_executor.models")
     LoadConfig = Any
     ParallelConfig = Any
-    RunnerType = Any
     QuantizationMethods = Any
     LogitsProcessor = Any
 
 logger = init_logger(__name__)
 
-RunnerOption = Literal["auto", "generate", "pooling", "draft"]
+RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -668,8 +667,28 @@
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
-        if getattr(self, "runner_type", self.runner) == "pooling":
-            return "TransformersModel"
+        # Check if the architecture we're wrapping has defaults
+        runner = None
+        convert = None
+        if defaults := try_match_architecture_defaults(self.architectures[0]):
+            _, (runner, convert) = defaults
+        # Overwrite with user-specified values
+        if self.runner != "auto":
+            runner = self.runner
+        if self.convert not in {"auto", "none"}:
+            convert = self.convert
+        # Fall back to default values if still not set
+        if runner is None:
+            runner = "generate"
+        if convert in {None, "none"}:
+            convert = "embed"
+        # Resolve Transformers backend pooling classes
+        if runner == "pooling":
+            if convert == "embed":
+                return "TransformersEmbeddingModel"
+            if convert == "classify":
+                return "TransformersForSequenceClassification"
+        # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
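The branch above resolves `(runner, convert)` to a backend class before falling through to the existing generate-path logic. A standalone sketch of the resulting mapping (illustrative only; the real method also consults `try_match_architecture_defaults` and the composite-config check below):

def resolve_backend_cls(runner: str, convert: str,
                        is_multimodal: bool = False) -> str:
    """Illustrative re-statement of the resolution added above."""
    if runner == "pooling":
        if convert == "embed":
            return "TransformersEmbeddingModel"
        if convert == "classify":
            return "TransformersForSequenceClassification"
    # Generate path (assumed here): composite/multimodal configs map to the
    # multimodal wrapper, plain text configs to the causal LM wrapper.
    return ("TransformersForMultimodalLM"
            if is_multimodal else "TransformersForCausalLM")


assert resolve_backend_cls("pooling", "embed") == "TransformersEmbeddingModel"
assert resolve_backend_cls("pooling", "classify") == \
    "TransformersForSequenceClassification"
assert resolve_backend_cls("generate", "none") == "TransformersForCausalLM"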
@@ -678,7 +697,9 @@
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
-        return self.architecture == self._get_transformers_backend_cls()
+        used_cls = self._model_info.architecture
+        transformers_backend_cls = self._get_transformers_backend_cls()
+        return used_cls == transformers_backend_cls
 
     @property
     def registry(self):
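The tests above reach this check through `model_config`; the same assertion works against a live engine. A small sketch, assuming the `LLM` entrypoint and its `model_impl` argument:

from vllm import LLM

# BertModel-style checkpoints resolve to runner="pooling", convert="embed",
# i.e. TransformersEmbeddingModel, when model_impl="transformers".
llm = LLM(model="BAAI/bge-base-en-v1.5", model_impl="transformers")
assert llm.llm_engine.model_config.using_transformers_backend()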

View File

@@ -4,6 +4,7 @@
 import ast
 import inspect
 import textwrap
+from collections.abc import Iterable
 from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 
@@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
         f"{cls.__name__}.{name} must have a default value or default factory.")
 
 
+def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
+    """
+    A helper function that retrieves an attribute from an object which may
+    have multiple possible names. This is useful when fetching attributes from
+    arbitrary `transformers.PretrainedConfig` instances.
+    """
+    for name in names:
+        if hasattr(object, name):
+            return getattr(object, name)
+    return default
+
+
 def contains_object_print(text: str) -> bool:
     """
     Check if the text looks like a printed Python object, e.g.
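A quick usage sketch for the new helper (the model id is only an example; ELECTRA-style configs expose `embedding_size` while BERT-style configs only define `hidden_size`):

from transformers import AutoConfig

from vllm.config.utils import getattr_iter

cfg = AutoConfig.from_pretrained("google/electra-small-discriminator")
# Returns `embedding_size` (128 here); falls back to `hidden_size` (256)
# for configs that do not define a separate embedding size.
embedding_dim = getattr_iter(cfg, ("embedding_size", "hidden_size"), None)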

View File

@@ -307,9 +307,10 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersModel": ("transformers", "TransformersModel"),
+    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
+    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
     "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
 }
 # yapf: enable
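With these registry entries, `*ForSequenceClassification` checkpoints can now fall back to the new pooling wrapper. A minimal sketch, assuming vLLM's `LLM.classify` API and, per the registry entry in the tests above, transformers >= 4.57.0.dev0:

from vllm import LLM

llm = LLM(model="papluca/xlm-roberta-base-language-detection",
          model_impl="transformers")
(result,) = llm.classify(["vLLM now has explicit pooling classes."])
print(len(result.outputs.probs))  # one probability per language label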

View File

@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
+from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
@@ -486,10 +487,13 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 
         # Input embeddings
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+            names = ("embedding_size", "hidden_size")
+            embedding_dim = getattr_iter(self.text_config, names, None)
+            assert embedding_dim is not None
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
                     self.text_config.vocab_size,
-                    self.text_config.hidden_size,
+                    embedding_dim=embedding_dim,
                     org_num_embeddings=self.text_config.vocab_size,
                     quant_config=self.quant_config,
                 ))
@@ -645,7 +649,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 attn_type=attn_type)
         return attention_instances
 
-    def init_parameters(self, module: nn.Module):
+    def init_parameters(self,
+                        module: nn.Module,
+                        dtype: Optional[torch.dtype] = None):
         """
         If a `parameter` is on the `meta` device, then its parent
         `module` is the original module created by:
@@ -659,11 +665,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             if param.device == torch.device("meta"):
                 new_param = nn.Parameter(
                     torch.empty_like(param.data,
-                                     dtype=self.model_config.dtype,
+                                     dtype=dtype or self.model_config.dtype,
                                      device=self.device_config.device))
                 setattr(module, name, new_param)
         for child in module.children():
-            self.init_parameters(child)
+            self.init_parameters(child, dtype)
 
     def forward(
         self,
@@ -712,73 +718,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
-class TransformersModel(TransformersBase):
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # Handle BERT-like models
-            "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Pooling adapters will be adjacent to `model`
-            "model.pooler": "pooler",
-            "model.score": "score",
-            # Classifier adapter's classifier layer is renamed to score
-            "model.classifier": "score",
-        },
-        orig_to_new_suffix={
-            # Replace legacy suffixes used for norms
-            ".gamma": ".weight",
-            ".beta": ".bias",
-        })
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        # After creating a pooling model, `pooler` will be duplicated.
-        # The one inside `model` comes from the Transformers modelling code.
-        # The one after `model` is an adapter from vLLM.
-        # We want to use the adapter so we nullify the original pooler.
-        if getattr(self.model, "pooler", None) is not None:
-            self.skip_prefixes.append("pooler.")
-            self.model.pooler = torch.nn.Identity()
-
-        # Some encoder models have the position_ids buffer in the checkpoint.
-        # vLLM will always pass position_ids as an argument, so we skip loading
-        # the buffer if it exists
-        self.skip_substrs.append("position_ids")
-
-        # Some encoder models have the bias of the final classifier layer
-        # in the checkpoint. vLLM does not use this bias, so we skip loading
-        # it if it exists
-        self.skip_substrs.append("score.bias")
-
-    def create_attention_instances(
-            self, attn_type: AttentionType = AttentionType.DECODER):
-        # TODO(hmellor): Better way to detect encoder models
-        # In encoder models, the attention layers will have `is_causal=False`
-        is_encoder = lambda m: not getattr(m, "is_causal", True)
-        # vLLM does not support encoder-decoder models, so if any encoder layer
-        # is found, we assume the whole model is an encoder model
-        if any(is_encoder(m) for m in self.model.modules()):
-            attn_type = AttentionType.ENCODER_ONLY
-
-        # Check minimum transformers version for encoder models support
-        if attn_type == AttentionType.ENCODER_ONLY:
-            import transformers
-            from packaging.version import Version
-            installed = Version(transformers.__version__)
-            required = Version("4.57.0.dev0")
-            if installed < required:
-                raise ValueError(
-                    "Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
-
-        return super().create_attention_instances(attn_type)
-
-
 @support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(TransformersBase):

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models for pooling tasks."""
from typing import Optional, Union

import torch
from transformers import AutoModelForSequenceClassification

from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                               DispatchPooler, Pooler)
from vllm.sequence import IntermediateTensors

from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile
from .utils import WeightsMapper


class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
    hf_to_vllm_mapper = WeightsMapper(
        # These are applied in order, so the order matters!
        orig_to_new_prefix={
            # Handle BERT-like models
            "roberta": "model",
            "bert": "model",
            # Add `model.` prefix for base model checkpoints
            "": "model.",
            # Remove `model.` prefix if it was already there
            "model.model.": "model.",
            # Classifier/scoring heads will be adjacent to `model`
            "model.score": "classifier",
            "model.classifier": "classifier",
        },
        orig_to_new_suffix={
            # Replace legacy suffixes used for norms
            ".gamma": ".weight",
            ".beta": ".bias",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Skip unsupported/unwanted output embeddings layers
        self.skip_prefixes.extend([
            "model.lm_head.", "model.predictions.", "model.qa_outputs.",
            "model.embeddings_project.", "model.discriminator_predictions."
        ])

        # Some encoder models have the position_ids buffer in the checkpoint.
        # vLLM will always pass position_ids as an argument, so we skip loading
        # the buffer if it exists
        self.skip_substrs.append("position_ids")

        # Some encoder models have the bias of the final classifier layer
        # in the checkpoint. vLLM does not use this bias, so we skip loading
        # it if it exists
        self.skip_substrs.append("score.bias")

        # RoBERTa-like models have an extra padding offset in their positions.
        # FIXME(Isotr0py): This is quite hacky for the RoBERTa edge case,
        # we should find a better way to handle this.
        self.is_roberta = "roberta" in self.text_config.model_type
        self.padding_idx = self.text_config.pad_token_id

    def create_attention_instances(
            self, attn_type: AttentionType = AttentionType.DECODER):
        # TODO(hmellor): Better way to detect encoder models
        # In encoder models, the attention layers will have `is_causal=False`
        is_encoder = lambda m: not getattr(m, "is_causal", True)
        # vLLM does not support encoder-decoder models, so if any encoder layer
        # is found, we assume the whole model is an encoder model
        if any(is_encoder(m) for m in self.model.modules()):
            attn_type = AttentionType.ENCODER_ONLY

        # Check minimum transformers version for encoder models support
        if attn_type == AttentionType.ENCODER_ONLY:
            import transformers
            from packaging.version import Version
            installed = Version(transformers.__version__)
            required = Version("4.57.0.dev0")
            if installed < required:
                raise ValueError(
                    "Encoder models with the Transformers backend require "
                    f"transformers>={required}, but got {installed}")

        return super().create_attention_instances(attn_type)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if self.is_roberta:
            # RoBERTa-specific positions padding
            positions += self.padding_idx + 1
        return super().forward(input_ids=input_ids,
                               positions=positions,
                               intermediate_tensors=intermediate_tensors,
                               inputs_embeds=inputs_embeds)


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(TransformersPoolingBase):
    default_pooling_type = "CLS"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({
            "encode": Pooler.for_encode(pooler_config),
            "embed": Pooler.for_embed(pooler_config),
        })


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(TransformersPoolingBase):
    default_pooling_type = "CLS"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # Certain information about the model and classifier can only be
        # inferred from the `ForSequenceClassification` class. Therefore, we
        # instantiate it on the "meta" device to avoid allocating GPU memory.
        with torch.device("meta"):
            seq_cls_model = AutoModelForSequenceClassification.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # When used for sequence classification, some models have their
        # pooling layers removed. Make sure this is reflected in vLLM.
        for module in seq_cls_model.modules():
            if hasattr(module, "pooler") and module.pooler is None:
                self.model.pooler = None
                break
        if self.model.pooler is not None:
            raise ValueError(
                "Sequence classification models with pooling layers are not "
                "supported yet in the Transformers backend.")

        # Unlike `lm_head`, `classifier` is not always `nn.Linear`.
        self.classifier = seq_cls_model.classifier
        self.init_parameters(self.classifier,
                             dtype=self.model_config.head_dtype)

        class ClassifierWithReshape(self.classifier.__class__):
            """CLSPool has already been applied in `pooling`.
            Add dim to match expected input shape of `classifier.forward`."""

            def forward(self, *args, **kwargs):
                if len(args) > 0:
                    args = (args[0].unsqueeze(1), *args[1:])
                return super().forward(*args, **kwargs)

        self.classifier.__class__ = ClassifierWithReshape

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_seq_cls(
                    vllm_config.model_config),
            ),
            "score":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                    vllm_config.model_config),
            ),
        })
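A shape sketch of why `ClassifierWithReshape` re-inserts a token dimension (illustrative only): `CLSPool` returns one hidden vector per sequence, while Hugging Face classification heads such as `RobertaClassificationHead` index the token dimension via `features[:, 0, :]`:

import torch

hidden_size, num_seqs = 768, 4
pooled = torch.randn(num_seqs, hidden_size)      # what CLSPool produces
head_input = pooled.unsqueeze(1)                 # [num_seqs, 1, hidden_size]
# The HF head's `features[:, 0, :]` then recovers exactly the pooled vectors.
assert torch.equal(head_input[:, 0, :], pooled)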