[New Model]: Support GteNewModelForSequenceClassification (#23524)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-08-28 15:36:42 +08:00 committed by GitHub
parent 186aced5ff
commit 11a7fafaa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 157 additions and 76 deletions

View File

@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|--------------|--------|-------------------|----------------------|---------------------------|---------------------| |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
``` ```
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
!!! note !!! note
Load the official original `mxbai-rerank-v2` by using the following command. Load the official original `mxbai-rerank-v2` by using the following command.

View File

@ -456,11 +456,10 @@ class HfRunner:
# output is final logits # output is final logits
all_inputs = self.get_inputs(prompts) all_inputs = self.get_inputs(prompts)
outputs = [] outputs = []
problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs: for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs)) output = self.model(**self.wrap_device(inputs))
problem_type = getattr(self.config, "problem_type", "")
if problem_type == "regression": if problem_type == "regression":
logits = output.logits[0].tolist() logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification": elif problem_type == "multi_label_classification":

View File

@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,

View File

@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,

View File

@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"), architecture="GemmaForSequenceClassification",
hf_overrides={
"architectures":
["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method":
"no_post_processing",
}),
] ]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
monkeypatch) -> None:
monkeypatch.setenv("VLLM_USE_V1", "0")
assert model_info.architecture == "GemmaForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
}
}
mteb_test_rerank_models(GemmaRerankerHfRunner, mteb_test_rerank_models(GemmaRerankerHfRunner,
vllm_runner, vllm_runner,
model_info, model_info,
vllm_extra_kwargs,
vllm_mteb_encoder=GemmaMtebEncoder) vllm_mteb_encoder=GemmaMtebEncoder)

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest import pytest
@ -33,12 +32,15 @@ MODELS = [
########### NewModel ########### NewModel
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
########### Qwen2ForCausalLM ########### Qwen2ForCausalLM
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@ -60,11 +62,16 @@ MODELS = [
] ]
RERANK_MODELS = [ RERANK_MODELS = [
# classifier_pooling: mean
CLSPoolingRerankModelInfo( CLSPoolingRerankModelInfo(
# classifier_pooling: mean
"Alibaba-NLP/gte-reranker-modernbert-base", "Alibaba-NLP/gte-reranker-modernbert-base",
architecture="ModernBertForSequenceClassification", architecture="ModernBertForSequenceClassification",
enable_test=True), enable_test=True),
CLSPoolingRerankModelInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
enable_test=True),
] ]
@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
check_transformers_version(model_info.name, check_transformers_version(model_info.name,
max_transformers_version="4.53.2") max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {} mteb_test_embed_models(hf_runner, vllm_runner, model_info)
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
check_transformers_version(model_info.name, check_transformers_version(model_info.name,
max_transformers_version="4.53.2") max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
correctness_test_embed_models(hf_runner, vllm_runner, model_info, correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs) example_prompts)
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)

View File

@ -10,12 +10,20 @@ from tests.conftest import HfRunner
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models from .mteb_utils import mteb_test_rerank_models
mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=True), enable_test=True),
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=False) enable_test=False)
] ]
@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {} mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)

View File

@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models from .mteb_utils import mteb_test_rerank_models
qwen3_reranker_hf_overrides = {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=True), enable_test=True),
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=False) enable_test=False)
] ]
@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
assert model_info.architecture == "Qwen3ForSequenceClassification" mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
}
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
assert model_info.architecture == "Qwen3ForSequenceClassification" assert model_info.architecture == "Qwen3ForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = { vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tensor_parallel_size": 2, "tensor_parallel_size": 2,
} }

View File

@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Cross-encoder] # [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501

View File

@ -3,7 +3,8 @@
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union from dataclasses import dataclass
from typing import Any, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1) return F.softmax(data, dim=-1)
class EmbedModelInfo(NamedTuple): @dataclass
class ModelInfo:
name: str name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = "" architecture: str = ""
dtype: str = "auto" dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = "" default_pooling_type: str = ""
enable_test: bool = True enable_test: bool = True
@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo): class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS" default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo): class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST" default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple): @dataclass
name: str class RerankModelInfo(ModelInfo):
architecture: str = "" pass
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo): class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS" default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo): class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST" default_pooling_type: str = "LAST"

View File

@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module):
class BertWithRope(nn.Module, SupportsQuant): class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
add_pooling_layer: bool = False):
super().__init__() super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.add_pooling_layer = add_pooling_layer
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config) self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder( self.encoder = BertWithRopeEncoder(
@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant):
bias=getattr(self.config, "bias", True), bias=getattr(self.config, "bias", True),
rotary_kwargs=self.config.rotary_kwargs, rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def forward( def forward(
self, self,
@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "pooler" in name: if not self.add_pooling_layer and "pooler" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
@ -508,8 +517,8 @@ class GteNewModel(BertWithRope):
"attention.o_proj": "attn.out_proj", "attention.o_proj": "attn.out_proj",
}) })
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# GteNewModel only gate_up_proj does not have bias. # GteNewModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py # Hack method learned from vllm/model_executor/models/glm.py
@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope):
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
weights = self.jina_merge_lora_weights(weights) weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights) return super().load_weights(weights)
@default_pooling_type("CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.new = GteNewModel(vllm_config=vllm_config,
prefix=prefix,
add_pooling_layer=True)
self.classifier = RowParallelLinear(config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "classifier"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.new(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

View File

@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig, "GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig, "GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"NomicBertModel": NomicBertModelConfig, "NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,

View File

@ -191,12 +191,14 @@ _EMBEDDING_MODELS = {
_CROSS_ENCODER_MODELS = { _CROSS_ENCODER_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"), "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GteNewForSequenceClassification": ("bert_with_rope",
"GteNewForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"), "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ("roberta", "XLMRobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"), "RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
# [Auto-converted (see adapters.py)] # [Auto-converted (see adapters.py)]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
} }