mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 13:51:19 +08:00
[New Model]: Support GteNewModelForSequenceClassification (#23524)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
186aced5ff
commit
11a7fafaa8
@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
|||||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
|
||||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
|
||||||
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
|
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
|
||||||
@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
|||||||
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
|
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because the name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to explicitly select the `GteNewForSequenceClassification` architecture.
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
Load the official original `mxbai-rerank-v2` by using the following command.
|
Load the official original `mxbai-rerank-v2` by using the following command.
|
||||||
|
|
||||||
|
|||||||
@ -456,11 +456,10 @@ class HfRunner:
|
|||||||
# output is final logits
|
# output is final logits
|
||||||
all_inputs = self.get_inputs(prompts)
|
all_inputs = self.get_inputs(prompts)
|
||||||
outputs = []
|
outputs = []
|
||||||
|
problem_type = getattr(self.config, "problem_type", "")
|
||||||
|
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
output = self.model(**self.wrap_device(inputs))
|
output = self.model(**self.wrap_device(inputs))
|
||||||
|
|
||||||
problem_type = getattr(self.config, "problem_type", "")
|
|
||||||
|
|
||||||
if problem_type == "regression":
|
if problem_type == "regression":
|
||||||
logits = output.logits[0].tolist()
|
logits = output.logits[0].tolist()
|
||||||
elif problem_type == "multi_label_classification":
|
elif problem_type == "multi_label_classification":
|
||||||
|
|||||||
@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
|
|||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||||
|
|
||||||
|
if model_info.hf_overrides is not None:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
|
|||||||
@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||||
|
|
||||||
|
if model_info.hf_overrides is not None:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||||
|
|
||||||
|
if model_info.hf_overrides is not None:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
|
|||||||
@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
|||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||||
architecture="GemmaForSequenceClassification"),
|
architecture="GemmaForSequenceClassification",
|
||||||
|
hf_overrides={
|
||||||
|
"architectures":
|
||||||
|
["GemmaForSequenceClassification"],
|
||||||
|
"classifier_from_token": ["Yes"],
|
||||||
|
"method":
|
||||||
|
"no_post_processing",
|
||||||
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
||||||
@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
|
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||||
monkeypatch) -> None:
|
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
|
||||||
|
|
||||||
assert model_info.architecture == "GemmaForSequenceClassification"
|
|
||||||
|
|
||||||
vllm_extra_kwargs: dict[str, Any] = {
|
|
||||||
"hf_overrides": {
|
|
||||||
"architectures": ["GemmaForSequenceClassification"],
|
|
||||||
"classifier_from_token": ["Yes"],
|
|
||||||
"method": "no_post_processing",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mteb_test_rerank_models(GemmaRerankerHfRunner,
|
mteb_test_rerank_models(GemmaRerankerHfRunner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
model_info,
|
model_info,
|
||||||
vllm_extra_kwargs,
|
|
||||||
vllm_mteb_encoder=GemmaMtebEncoder)
|
vllm_mteb_encoder=GemmaMtebEncoder)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -33,12 +32,15 @@ MODELS = [
|
|||||||
########### NewModel
|
########### NewModel
|
||||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
|
hf_overrides={"architectures": ["GteNewModel"]},
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
|
hf_overrides={"architectures": ["GteNewModel"]},
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
|
hf_overrides={"architectures": ["GteNewModel"]},
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
########### Qwen2ForCausalLM
|
########### Qwen2ForCausalLM
|
||||||
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
@ -60,11 +62,16 @@ MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
# classifier_pooling: mean
|
|
||||||
CLSPoolingRerankModelInfo(
|
CLSPoolingRerankModelInfo(
|
||||||
|
# classifier_pooling: mean
|
||||||
"Alibaba-NLP/gte-reranker-modernbert-base",
|
"Alibaba-NLP/gte-reranker-modernbert-base",
|
||||||
architecture="ModernBertForSequenceClassification",
|
architecture="ModernBertForSequenceClassification",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
|
CLSPoolingRerankModelInfo(
|
||||||
|
"Alibaba-NLP/gte-multilingual-reranker-base",
|
||||||
|
architecture="GteNewForSequenceClassification",
|
||||||
|
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
|
||||||
|
enable_test=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
|
|||||||
check_transformers_version(model_info.name,
|
check_transformers_version(model_info.name,
|
||||||
max_transformers_version="4.53.2")
|
max_transformers_version="4.53.2")
|
||||||
|
|
||||||
vllm_extra_kwargs: dict[str, Any] = {}
|
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||||
if model_info.architecture == "GteNewModel":
|
|
||||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
|
||||||
|
|
||||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
|
|
||||||
vllm_extra_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", MODELS)
|
@pytest.mark.parametrize("model_info", MODELS)
|
||||||
@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
|
|||||||
check_transformers_version(model_info.name,
|
check_transformers_version(model_info.name,
|
||||||
max_transformers_version="4.53.2")
|
max_transformers_version="4.53.2")
|
||||||
|
|
||||||
vllm_extra_kwargs: dict[str, Any] = {}
|
|
||||||
if model_info.architecture == "GteNewModel":
|
|
||||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
|
||||||
|
|
||||||
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
|
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
|
||||||
example_prompts, vllm_extra_kwargs)
|
example_prompts)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||||
|
|||||||
@ -10,12 +10,20 @@ from tests.conftest import HfRunner
|
|||||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||||
from .mteb_utils import mteb_test_rerank_models
|
from .mteb_utils import mteb_test_rerank_models
|
||||||
|
|
||||||
|
mxbai_rerank_hf_overrides = {
|
||||||
|
"architectures": ["Qwen2ForSequenceClassification"],
|
||||||
|
"classifier_from_token": ["0", "1"],
|
||||||
|
"method": "from_2_way_softmax",
|
||||||
|
}
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
||||||
architecture="Qwen2ForSequenceClassification",
|
architecture="Qwen2ForSequenceClassification",
|
||||||
|
hf_overrides=mxbai_rerank_hf_overrides,
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
||||||
architecture="Qwen2ForSequenceClassification",
|
architecture="Qwen2ForSequenceClassification",
|
||||||
|
hf_overrides=mxbai_rerank_hf_overrides,
|
||||||
enable_test=False)
|
enable_test=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||||
vllm_extra_kwargs: dict[str, Any] = {}
|
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
|
||||||
if model_info.architecture == "Qwen2ForSequenceClassification":
|
|
||||||
vllm_extra_kwargs["hf_overrides"] = {
|
|
||||||
"architectures": ["Qwen2ForSequenceClassification"],
|
|
||||||
"classifier_from_token": ["0", "1"],
|
|
||||||
"method": "from_2_way_softmax",
|
|
||||||
}
|
|
||||||
|
|
||||||
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
|
|
||||||
vllm_extra_kwargs)
|
|
||||||
|
|||||||
@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test
|
|||||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||||
from .mteb_utils import mteb_test_rerank_models
|
from .mteb_utils import mteb_test_rerank_models
|
||||||
|
|
||||||
|
qwen3_reranker_hf_overrides = {
|
||||||
|
"architectures": ["Qwen3ForSequenceClassification"],
|
||||||
|
"classifier_from_token": ["no", "yes"],
|
||||||
|
"is_original_qwen3_reranker": True,
|
||||||
|
}
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||||
architecture="Qwen3ForSequenceClassification",
|
architecture="Qwen3ForSequenceClassification",
|
||||||
|
hf_overrides=qwen3_reranker_hf_overrides,
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||||
architecture="Qwen3ForSequenceClassification",
|
architecture="Qwen3ForSequenceClassification",
|
||||||
|
hf_overrides=qwen3_reranker_hf_overrides,
|
||||||
enable_test=False)
|
enable_test=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner):
|
|||||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||||
|
|
||||||
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
|
||||||
|
|
||||||
vllm_extra_kwargs: dict[str, Any] = {
|
|
||||||
"hf_overrides": {
|
|
||||||
"architectures": ["Qwen3ForSequenceClassification"],
|
|
||||||
"classifier_from_token": ["no", "yes"],
|
|
||||||
"is_original_qwen3_reranker": True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
|
|
||||||
vllm_extra_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||||
@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
|
|||||||
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
||||||
|
|
||||||
vllm_extra_kwargs: dict[str, Any] = {
|
vllm_extra_kwargs: dict[str, Any] = {
|
||||||
"hf_overrides": {
|
|
||||||
"architectures": ["Qwen3ForSequenceClassification"],
|
|
||||||
"classifier_from_token": ["no", "yes"],
|
|
||||||
"is_original_qwen3_reranker": True,
|
|
||||||
},
|
|
||||||
"tensor_parallel_size": 2,
|
"tensor_parallel_size": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
|||||||
|
|
||||||
# [Cross-encoder]
|
# [Cross-encoder]
|
||||||
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
|
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
|
||||||
|
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
|
||||||
|
trust_remote_code=True,
|
||||||
|
hf_overrides={
|
||||||
|
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
|
||||||
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
|
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
|
||||||
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
|
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
|
||||||
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
|
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
|
||||||
|
|||||||
@ -3,7 +3,8 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, NamedTuple, Optional, Union
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -339,36 +340,43 @@ def softmax(data):
|
|||||||
return F.softmax(data, dim=-1)
|
return F.softmax(data, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
class EmbedModelInfo(NamedTuple):
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
name: str
|
name: str
|
||||||
is_matryoshka: bool = False
|
|
||||||
matryoshka_dimensions: Optional[list[int]] = None
|
|
||||||
architecture: str = ""
|
architecture: str = ""
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
hf_overrides: Optional[dict[str, Any]] = None
|
||||||
default_pooling_type: str = ""
|
default_pooling_type: str = ""
|
||||||
enable_test: bool = True
|
enable_test: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbedModelInfo(ModelInfo):
|
||||||
|
is_matryoshka: bool = False
|
||||||
|
matryoshka_dimensions: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
|
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
|
||||||
default_pooling_type: str = "CLS"
|
default_pooling_type: str = "CLS"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
||||||
default_pooling_type: str = "LAST"
|
default_pooling_type: str = "LAST"
|
||||||
|
|
||||||
|
|
||||||
class RerankModelInfo(NamedTuple):
|
@dataclass
|
||||||
name: str
|
class RerankModelInfo(ModelInfo):
|
||||||
architecture: str = ""
|
pass
|
||||||
dtype: str = "auto"
|
|
||||||
default_pooling_type: str = ""
|
|
||||||
enable_test: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class CLSPoolingRerankModelInfo(RerankModelInfo):
|
class CLSPoolingRerankModelInfo(RerankModelInfo):
|
||||||
default_pooling_type: str = "CLS"
|
default_pooling_type: str = "CLS"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class LASTPoolingRerankModelInfo(RerankModelInfo):
|
class LASTPoolingRerankModelInfo(RerankModelInfo):
|
||||||
default_pooling_type: str = "LAST"
|
default_pooling_type: str = "LAST"
|
||||||
|
|
||||||
|
|||||||
@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
|
maybe_prefix)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsQuant
|
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
|
||||||
|
from .bert import BertPooler
|
||||||
|
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||||
from .interfaces_base import default_pooling_type
|
from .interfaces_base import default_pooling_type
|
||||||
|
|
||||||
|
|
||||||
@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module):
|
|||||||
class BertWithRope(nn.Module, SupportsQuant):
|
class BertWithRope(nn.Module, SupportsQuant):
|
||||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
add_pooling_layer: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.add_pooling_layer = add_pooling_layer
|
||||||
self.config = vllm_config.model_config.hf_config
|
self.config = vllm_config.model_config.hf_config
|
||||||
self.embeddings = BertWithRopeEmbedding(self.config)
|
self.embeddings = BertWithRopeEmbedding(self.config)
|
||||||
self.encoder = BertWithRopeEncoder(
|
self.encoder = BertWithRopeEncoder(
|
||||||
@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
|||||||
bias=getattr(self.config, "bias", True),
|
bias=getattr(self.config, "bias", True),
|
||||||
rotary_kwargs=self.config.rotary_kwargs,
|
rotary_kwargs=self.config.rotary_kwargs,
|
||||||
prefix=f"{prefix}.encoder")
|
prefix=f"{prefix}.encoder")
|
||||||
|
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "pooler" in name:
|
if not self.add_pooling_layer and "pooler" in name:
|
||||||
continue
|
continue
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
@ -508,8 +517,8 @@ class GteNewModel(BertWithRope):
|
|||||||
"attention.o_proj": "attn.out_proj",
|
"attention.o_proj": "attn.out_proj",
|
||||||
})
|
})
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
|
|
||||||
# GteNewModel only gate_up_proj does not have bias.
|
# GteNewModel only gate_up_proj does not have bias.
|
||||||
# Hack method learned from vllm/model_executor/models/glm.py
|
# Hack method learned from vllm/model_executor/models/glm.py
|
||||||
@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope):
|
|||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
weights = self.jina_merge_lora_weights(weights)
|
weights = self.jina_merge_lora_weights(weights)
|
||||||
return super().load_weights(weights)
|
return super().load_weights(weights)
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
|
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||||
|
is_pooling_model = True
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
self.new = GteNewModel(vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
add_pooling_layer=True)
|
||||||
|
self.classifier = RowParallelLinear(config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
input_is_parallel=False,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "classifier"),
|
||||||
|
return_bias=False)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler({
|
||||||
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=self.new.pooler,
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=self.new.pooler,
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
loaded_params = loader.load_weights(weights)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
return self.new(input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors)
|
||||||
|
|||||||
@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||||
"GteModel": SnowflakeGteNewModelConfig,
|
"GteModel": SnowflakeGteNewModelConfig,
|
||||||
"GteNewModel": GteNewModelConfig,
|
"GteNewModel": GteNewModelConfig,
|
||||||
|
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||||
"NomicBertModel": NomicBertModelConfig,
|
"NomicBertModel": NomicBertModelConfig,
|
||||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||||
|
|||||||
@ -191,12 +191,14 @@ _EMBEDDING_MODELS = {
|
|||||||
|
|
||||||
_CROSS_ENCODER_MODELS = {
|
_CROSS_ENCODER_MODELS = {
|
||||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||||
|
"GteNewForSequenceClassification": ("bert_with_rope",
|
||||||
|
"GteNewForSequenceClassification"),
|
||||||
|
"ModernBertForSequenceClassification": ("modernbert",
|
||||||
|
"ModernBertForSequenceClassification"),
|
||||||
"RobertaForSequenceClassification": ("roberta",
|
"RobertaForSequenceClassification": ("roberta",
|
||||||
"RobertaForSequenceClassification"),
|
"RobertaForSequenceClassification"),
|
||||||
"XLMRobertaForSequenceClassification": ("roberta",
|
"XLMRobertaForSequenceClassification": ("roberta",
|
||||||
"RobertaForSequenceClassification"),
|
"RobertaForSequenceClassification"),
|
||||||
"ModernBertForSequenceClassification": ("modernbert",
|
|
||||||
"ModernBertForSequenceClassification"),
|
|
||||||
# [Auto-converted (see adapters.py)]
|
# [Auto-converted (see adapters.py)]
|
||||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user