[Model][3/N] Automatic conversion of CrossEncoding model (#20168)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-07-04 20:47:39 +08:00 committed by GitHub
parent 9e5452ee34
commit 2e26f9156a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 234 additions and 133 deletions

View File

@@ -478,18 +478,27 @@ If your model is not in the above list, we will try to automatically convert the
Specified using `--task score`.
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
!!! note
Load the official original `mxbai-rerank-v2` by using the following command.
```bash
vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
```
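As a rough illustration, once the server above is running it can be queried like any other scoring model. The snippet below is a sketch only: the default port 8000, the `/score` route, and the `text_1`/`text_2` payload fields are assumptions based on vLLM's scoring API and may differ between versions.
```python
# Sketch of a client call against the server started above; the /score route,
# port, and payload fields are assumptions, not part of this change.
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "mixedbread-ai/mxbai-rerank-base-v2",
        "text_1": "What is the capital of France?",
        "text_2": ["Paris is the capital of France.",
                   "The Eiffel Tower is in Berlin."],
    },
)
print(response.json())
```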
!!! note
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.
```bash
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
```
[](){ #supported-mm-models }
## List of Multimodal Language Models
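The same override can also be exercised offline. A minimal sketch, assuming `LLM(..., hf_overrides=...)` and `LLM.score()` behave as in current vLLM releases:
```python
# Minimal offline sketch; hf_overrides mirrors the serve command above.
# The exact output object layout (output.outputs.score) may vary by version.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

query = "What is the capital of China?"
documents = [
    "Beijing is the capital of China.",
    "Pandas live in the mountains of central China.",
]
# score() runs the converted cross-encoder over (query, document) pairs
# and returns one relevance score per document.
outputs = llm.score(query, documents)
print([output.outputs.score for output in outputs])
```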

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
import pytest
@@ -74,6 +75,13 @@ def test_models(
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN", normalize=False)
max_model_len: Optional[int] = 512
if model in [
"sentence-transformers/all-MiniLM-L12-v2",
"sentence-transformers/stsb-roberta-base-v2"
]:
max_model_len = None
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
@@ -87,7 +95,7 @@ def test_models(
with vllm_runner(model,
task="embed",
-    max_model_len=512,
+    max_model_len=max_model_len,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.embed(example_prompts)

View File

@@ -56,10 +56,16 @@ MODELS = [
enable_test=False),
]
V1FlashAttentionImpNotSupported = [
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base"
]
@pytest.mark.parametrize("model_info", MODELS)
-def test_embed_models_mteb(hf_runner, vllm_runner,
-    model_info: EmbedModelInfo) -> None:
+def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
+    monkeypatch) -> None:
if model_info.name in V1FlashAttentionImpNotSupported:
monkeypatch.setenv("VLLM_USE_V1", "0")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
@@ -71,8 +77,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
-    model_info: EmbedModelInfo,
-    example_prompts) -> None:
+    model_info: EmbedModelInfo, example_prompts,
+    monkeypatch) -> None:
if model_info.name in V1FlashAttentionImpNotSupported:
monkeypatch.setenv("VLLM_USE_V1", "0")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":

View File

@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
import torch
from tests.conftest import HfRunner
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
dtype="float32",
enable_test=True),
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
dtype="float32",
enable_test=False)
]
class MxbaiRerankerHfRunner(HfRunner):
def __init__(self,
model_name: str,
dtype: str = "auto",
*args: Any,
**kwargs: Any) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
self.tokenizer = AutoTokenizer.from_pretrained(model_name,
padding_side='left')
self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
self.no_loc = self.tokenizer.convert_tokens_to_ids("0")
def predict(self, prompts: list[list[str]], *args,
**kwargs) -> torch.Tensor:
def process_inputs(pairs):
inputs = self.tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False)
for i, ele in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = ele
inputs = self.tokenizer.pad(inputs,
padding=True,
return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs
@torch.no_grad()
def compute_logits(inputs):
logits = self.model(**inputs).logits[:, -1, :]
yes_logits = logits[:, self.yes_loc]
no_logits = logits[:, self.no_loc]
logits = yes_logits - no_logits
scores = logits.float().sigmoid()
return scores
scores = []
for prompt in prompts:
inputs = process_inputs([prompt])
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)

View File

@@ -466,6 +466,9 @@ class ModelConfig:
"affect the random state of the Python process that "
"launched vLLM.", self.seed)
# Set served_model_name before maybe_model_redirect(self.model), which may rewrite self.model
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
@@ -609,8 +612,6 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.multimodal_config = self._init_multimodal_config()
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@@ -1420,7 +1421,7 @@ class ModelConfig:
@property
def is_cross_encoder(self) -> bool:
-    return self.registry.is_cross_encoder_model(self.architectures)
+    return self.task == "classify"
@property
def use_mla(self) -> bool:
@@ -4762,6 +4763,12 @@ class VllmConfig:
if cls is not None:
cls.verify_and_update_config(self)
if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
from vllm.model_executor.models.adapters import (
SequenceClassificationConfig)
SequenceClassificationConfig.verify_and_update_config(self)
def __str__(self):
return (
f"model={self.model_config.model!r},"

View File

@@ -2,14 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
import torch
import torch.nn as nn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def as_seq_cls_model(cls: _T) -> _T:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def as_seq_cls_model(cls: _T) -> _T:
]
return PoolerOutput(outputs=pooled_outputs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
if tokens is None and method is None:
return super().load_weights(weights)
else:
# Online convert ForCausalLM into
# ForSequenceClassification model.
return seq_cls_model_loader(self, weights)
ModelForSequenceClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
_get_pooling_model_name(cls.__name__, "ForReward")
return ModelForReward # type: ignore
class SequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
method = getattr(config, "method", None)
tokens = getattr(config, "classifier_from_token", None)
if method is None:
return
assert tokens is not None
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
if method == "from_2_way_softmax":
assert len(tokens) == 2
config.num_labels = 1
else:
config.num_labels = len(tokens)
def load_weights_using_from_2_way_softmax(
model, weights: Iterable[tuple[str, torch.Tensor]]):
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", [])
tokens = cast(list[str], tokens)
assert len(tokens) == 2
device = model.score.weight.device
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
model.lm_head = ParallelLMHead(model.config.vocab_size,
model.config.hidden_size,
quant_config=model.quant_config)
loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
tokenizer = get_tokenizer(model_config.tokenizer,
revision=model_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[true_id].to(device).to(
torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
torch.float32)
model.score.weight.data.copy_(weight)
del model.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
return loaded_weights
SEQ_CLS_LOAD_METHODS = {
"from_2_way_softmax": load_weights_using_from_2_way_softmax,
}
def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
# Online convert ForCausalLM into ForSequenceClassification model.
# - from_2_way_softmax:
# - Qwen3ForCausalLM
# - Qwen3-Reranker
# - Qwen2ForCausalLM
# - mxbai-rerank-v2
config = model.vllm_config.model_config.hf_config
method = getattr(config, "method", None)
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
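The name `from_2_way_softmax` reflects the identity softmax([z_no, z_yes])[yes] = sigmoid(z_yes - z_no): copying `lm_head[yes] - lm_head[no]` into the single `score` row reproduces the original two-token probabilities. A small self-contained sketch (plain torch, not vLLM internals) illustrating this:
```python
# Illustration of the from_2_way_softmax conversion: scoring with a single row
# equal to lm_head[yes] - lm_head[no] matches the 2-way softmax probability.
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 16)        # pooled hidden states (batch, hidden_size)
lm_head = torch.randn(1000, 16)    # toy vocabulary projection
no_id, yes_id = 17, 42             # stand-ins for the classifier token ids

logits = hidden @ lm_head.T
p_yes_softmax = torch.softmax(logits[:, [no_id, yes_id]], dim=-1)[:, 1]

score_row = lm_head[yes_id] - lm_head[no_id]   # what the loader copies into score.weight
p_yes_sigmoid = torch.sigmoid(hidden @ score_row)

assert torch.allclose(p_yes_softmax, p_yes_sigmoid, atol=1e-6)
```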

View File

@@ -167,7 +167,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
assert tokens is not None and len(tokens) == 2, \
("Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
-config.num_labels = 1
+vllm_config.model_config.hf_config.method = "from_2_way_softmax"
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):

View File

@@ -38,15 +38,14 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
+from .adapters import as_seq_cls_model
+from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -323,114 +322,4 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
+Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)
-class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
-                                     SupportsCrossEncoding):
def __init__(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
self.prefix = prefix
self.model = Qwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(prefix, "score"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata)
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
is_original_qwen3_reranker = getattr(self.config,
"is_original_qwen3_reranker",
False)
if not is_original_qwen3_reranker:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return self.load_weights_from_original_qwen3_reranker(weights)
def load_weights_from_original_qwen3_reranker(
self, weights: Iterable[tuple[str, torch.Tensor]]):
model_config = self.vllm_config.model_config
tokens = getattr(self.config, "classifier_from_token", None)
device = self.score.weight.device
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(
self.prefix, "lm_head"))
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
tokenizer = get_tokenizer(
model_config.tokenizer,
revision=model_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
a = tokenizer.convert_tokens_to_ids(tokens[0])
b = tokenizer.convert_tokens_to_ids(tokens[1])
weight = self.lm_head.weight.data[b].to(
device) - self.lm_head.weight.data[a].to(device)
self.score.weight.data.copy_(weight)
del self.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
return loaded_weights