[Model] Re-add the implicit conversion feature for as_seq_cls_model (#21103)

Signed-off-by: wang.yuqi <noooop@126.com>
Authored by wang.yuqi on 2025-07-18 15:15:07 +08:00; committed by GitHub
parent ba2dfbb0c2
commit ca4eb82bcb
11 changed files with 165 additions and 75 deletions
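Context: this change lets vLLM load a *ForSequenceClassification checkpoint even when only the matching *ForCausalLM implementation is registered, by wrapping it on the fly with as_seq_cls_model. A minimal usage sketch (illustrative only; the model name is taken from the test registry below, and the standard offline LLM entry point is assumed):

    # Hedged sketch: this checkpoint has no native vLLM class; the implicit
    # conversion re-added here is expected to wrap Qwen2ForCausalLM.
    from vllm import LLM

    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
    (out, ) = llm.classify(["vLLM is a high-throughput inference engine."])
    print(out.outputs.probs)  # per-class probabilities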

View File: tests/models/registry.py

@@ -265,7 +265,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
-    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),  # noqa: E501
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -292,7 +291,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True),  # noqa: E501
-    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -311,7 +309,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
     "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
-    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True),  # noqa: E501
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True),  # noqa: E501
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True),  # noqa: E501
@@ -324,20 +321,29 @@ _EMBEDDING_EXAMPLE_MODELS = {
                                             is_available_online=False),  # noqa: E501
 }
 
-_CROSS_ENCODER_EXAMPLE_MODELS = {
-    # [Text-only]
+_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
+    # [Decoder-only]
+    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
+
+    # [Cross-encoder]
     "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
-    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
-                                                      v0_only=True,
-                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
-                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
-                                                                    "method": "no_post_processing"}),  # noqa: E501
-    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"),  # noqa: E501
     "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
     "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
     "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
 }
 
+_AUTOMATIC_CONVERTED_MODELS = {
+    # Use as_seq_cls_model for automatic conversion
+    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
+                                                      v0_only=True,
+                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
+                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
+                                                                    "method": "no_post_processing"}),  # noqa: E501
+    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"),  # noqa: E501
+    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
+    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
+}
+
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
@@ -449,6 +455,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"),  # noqa: E501
 }
 
 _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
     "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
                                   speculative_model="abhigoyal/vllm-eagle-llama-68m-random"),  # noqa: E501
@@ -489,7 +496,7 @@ _TRANSFORMERS_MODELS = {
 _EXAMPLE_MODELS = {
     **_TEXT_GENERATION_EXAMPLE_MODELS,
     **_EMBEDDING_EXAMPLE_MODELS,
-    **_CROSS_ENCODER_EXAMPLE_MODELS,
+    **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
     **_MULTIMODAL_EXAMPLE_MODELS,
     **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
     **_TRANSFORMERS_MODELS,
@@ -522,3 +529,4 @@ class HfExampleModels:
 
 HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
+AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)

View File: tests/models/test_initialization.py

@@ -13,20 +13,21 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.engine.core import EngineCore as V1EngineCore
 
 from ..utils import create_new_process_for_each_test
-from .registry import HF_EXAMPLE_MODELS
+from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels
 
 
-@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 @create_new_process_for_each_test()
-def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
+def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
+                   EXAMPLE_MODELS: HfExampleModels):
     """The reason for using create_new_process_for_each_test is to avoid
     the WARNING:
         "We must use the 'spawn' multiprocessing start method. Overriding
         VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
     The spawn process causes the _initialize_kv_caches_v1 function below to
     become ineffective.
     """
-    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+
+    model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
@@ -127,3 +128,15 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
         load_format="dummy",
         hf_overrides=hf_overrides,
     )
+
+
+@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
+def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
+    can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
+
+
+@pytest.mark.parametrize("model_arch",
+                         AUTO_EXAMPLE_MODELS.get_supported_archs())
+def test_implicit_converted_models(model_arch: str,
+                                   monkeypatch: pytest.MonkeyPatch):
+    can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS)

View File: tests/models/test_transformers.py

@@ -138,3 +138,38 @@ def test_quantization(
         name_0="transformers",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["jason9693/Qwen2.5-1.5B-apeach"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    monkeypatch,
+) -> None:
+    import torch
+    from transformers import AutoModelForSequenceClassification
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     model_impl="transformers") as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)

View File: vllm/config.py

@@ -551,7 +551,7 @@ class ModelConfig:
         # For pooling models, self.task is used to indicate the
         # user-selected task
         if self.task == "score":
-            if self.registry.is_cross_encoder_model(self.architectures):
+            if self._is_classify_task(self.architectures):
                 self.task = "classify"
             else:
                 self.task = "embed"
@@ -806,6 +806,12 @@ class ModelConfig:
                 f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _is_classify_task(self, architectures: list[str]):
+        for arch in architectures:
+            if arch.endswith("ForSequenceClassification"):
+                return True
+        return self.registry.is_cross_encoder_model(architectures)
+
     def _get_preferred_pooling_task(
         self,
         architectures: list[str],
@@ -813,14 +819,11 @@ class ModelConfig:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
-        if self.registry.is_cross_encoder_model(architectures):
-            return "classify"
         if self.registry.is_transcription_model(architectures):
             return "transcription"
 
         suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
             # Other models follow this pattern
-            ("ForSequenceClassification", "classify"),
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
@@ -878,11 +881,14 @@ class ModelConfig:
         self,
         task_option: TaskOption,
     ) -> dict[RunnerType, list[_ResolvedTask]]:
-        return {
-            "generate": self._get_supported_generation_tasks(task_option),
-            "pooling": self._get_supported_pooling_tasks(task_option),
-            "draft": ["draft"]
-        }
+        if self._is_classify_task(self.architectures):
+            return {"generate": [], "pooling": ["classify"], "draft": []}
+        else:
+            return {
+                "generate": self._get_supported_generation_tasks(task_option),
+                "pooling": self._get_supported_pooling_tasks(task_option),
+                "draft": ["draft"]
+            }
 
     def _get_supported_runner_types(
         self,
@@ -925,12 +931,16 @@ class ModelConfig:
                     f"Available tasks for runner={task_runner!r}: "
                     f"{supported_tasks[task_runner]}")
 
+        if "classify" in supported_tasks.get("pooling", []):
+            # When multiple pooling tasks are present, default to
+            # pooling (eg cross-encoder) for non-standard architectures.
+            return "pooling"
+
         suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
             ("ForCausalLM", "generate"),
             ("ForConditionalGeneration", "generate"),
             ("ChatModel", "generate"),
             ("LMHeadModel", "generate"),
-            ("ForSequenceClassification", "pooling"),
             ("EmbeddingModel", "pooling"),
             ("RewardModel", "pooling"),
         ]
@@ -940,10 +950,6 @@ class ModelConfig:
             if arch.endswith(suffix) and pref_runner in supported_runner_types:
                 return pref_runner
 
-        if "classify" in supported_tasks.get("pooling", []):
-            # When multiple pooling tasks are present, default to
-            # pooling (eg cross-encoder) for non-standard architectures.
-            return "pooling"
         if "generate" in supported_runner_types:
             return "generate"
         if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ class ModelConfig:
     @property
     def is_matryoshka(self) -> bool:
-        return (hasattr(self.hf_config, "matryoshka_dimensions")
+        return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
                 or getattr(self.hf_config, "is_matryoshka", False))
 
     @property
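The is_matryoshka change is a behavior fix, not just style: a config that contains the key with a null value ("matryoshka_dimensions": null) passes hasattr but carries no dimensions. A standalone illustration (SimpleNamespace stands in for hf_config):

    from types import SimpleNamespace

    hf_config = SimpleNamespace(matryoshka_dimensions=None)  # key present, value null
    print(hasattr(hf_config, "matryoshka_dimensions"))              # True  (old check)
    print(bool(getattr(hf_config, "matryoshka_dimensions", None)))  # False (new check)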
@@ -1539,13 +1545,11 @@ class ModelConfig:
         return getattr(self.hf_config, "use_pad_token", True)
 
     def get_and_verify_max_len(self, max_model_len: int):
-        # For pooling models, the tokenizer's `model_max_length` is often a
-        # reliable source for the maximum sequence length. However, for
-        # generative models, this can be incorrect and unduly limit the
-        # context window (e.g., DeepSeek-R1). Therefore, we only consider
-        # tokenizer_config for pooling models.
+        # Consider max_model_len in tokenizer_config only when
+        # pooling models use absolute position_embedding.
         tokenizer_config = None
-        if self.runner_type == "pooling":
+        if (self.runner_type == "pooling" and getattr(
+                self.hf_config, "position_embedding_type", "") == "absolute"):
             tokenizer_config = try_get_tokenizer_config(
                 self.tokenizer,
                 trust_remote_code=self.trust_remote_code,

View File: vllm/model_executor/model_loader/utils.py

@@ -22,7 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_embedding_model,
-                                                 as_reward_model)
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
@@ -238,9 +239,29 @@ def get_model_architecture(
     vllm_supported_archs = ModelRegistry.get_supported_archs()
     vllm_not_supported = not any(arch in vllm_supported_archs
                                  for arch in architectures)
+
+    if vllm_not_supported:
+        # try automatic conversion in adapters.py
+        for arch in architectures:
+            if not arch.endswith("ForSequenceClassification"):
+                continue
+
+            assert model_config.task == "classify"
+            causal_lm_arch = arch.replace("ForSequenceClassification",
+                                          "ForCausalLM")
+            causal_lm_arch_vllm_supported = (causal_lm_arch
+                                             in vllm_supported_archs)
+            if not causal_lm_arch_vllm_supported:
+                continue
+
+            architectures = [causal_lm_arch]
+            vllm_not_supported = False
+            break
+
     if (model_config.model_impl == ModelImpl.TRANSFORMERS or
             model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
         architectures = resolve_transformers_arch(model_config, architectures)
+        logger.debug_once("Resolve transformers arch %s", str(architectures))
     elif (model_config.quantization is not None
           and model_config.quantization not in mixtral_supported
           and "MixtralForCausalLM" in architectures):
@@ -248,12 +269,13 @@ def get_model_architecture(
 
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.task == "embed":
+        logger.debug_once("Automatic conversion using `as_embedding_model`.")
         model_cls = as_embedding_model(model_cls)
     elif model_config.task == "classify":
-        # Cannot automatically run as_seq_cls_model,
-        # otherwise it will cause a circular reference on is_cross_encoder_model
-        pass
+        logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
+        model_cls = as_seq_cls_model(model_cls)
     elif model_config.task == "reward":
+        logger.debug_once("Automatic conversion using `as_reward_model`.")
         model_cls = as_reward_model(model_cls)
 
     return model_cls, arch

View File: vllm/model_executor/models/adapters.py

@@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(
     false_id = tokenizer.convert_tokens_to_ids(tokens[0])
     true_id = tokenizer.convert_tokens_to_ids(tokens[1])
-    weight = model.lm_head.weight.data[[true_id]].to(
+    score_weight = model.lm_head.weight.data[[true_id]].to(
         torch.float32) - model.lm_head.weight.data[[false_id]].to(
             torch.float32)
 
     param = model.score.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
-    weight_loader(param, weight)
+    weight_loader(param, score_weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")
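For context on load_weights_using_from_2_way_softmax (only the variable was renamed here): the classifier row is the difference of the "true" and "false" token rows of lm_head, because a softmax over two logits reduces to a sigmoid of their difference. A quick standalone sanity check of that identity:

    import torch

    h = torch.randn(16)                                  # a hidden state
    w_false, w_true = torch.randn(16), torch.randn(16)   # lm_head rows of the two tokens
    p_softmax = torch.softmax(torch.stack([w_false @ h, w_true @ h]), dim=0)[1]
    p_sigmoid = torch.sigmoid((w_true - w_false) @ h)    # what the score weight computes
    assert torch.allclose(p_softmax, p_sigmoid, atol=1e-6)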
@@ -350,6 +350,8 @@ def load_weights_no_post_processing(model,
                                     torch.Tensor]]):
     from vllm.model_executor.layers.vocab_parallel_embedding import (
         ParallelLMHead)
+    from vllm.model_executor.model_loader.weight_utils import (
+        default_weight_loader)
     from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
@@ -357,8 +359,6 @@ def load_weights_no_post_processing(model,
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0
 
-    device = model.score.weight.device
-
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
@@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
         trust_remote_code=model_config.trust_remote_code)
     token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
 
-    score_weight = model.lm_head.weight.data[token_ids].to(device)
-    model.score.weight.data.copy_(score_weight)
+    score_weight = model.lm_head.weight.data[token_ids]
+
+    param = model.score.weight
+    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+    weight_loader(param, score_weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")

View File: vllm/model_executor/models/gemma.py

@@ -43,7 +43,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -426,6 +425,3 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)

View File: vllm/model_executor/models/llama.py

@@ -49,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
@@ -646,6 +645,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             name = name.replace(item, mapping[item])
         return name, loaded_weight
-
-
-LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)

View File: vllm/model_executor/models/qwen2.py

@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
@@ -496,6 +495,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

View File: vllm/model_executor/models/qwen3.py

@@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
@@ -320,6 +319,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)

View File: vllm/model_executor/models/registry.py

@@ -12,7 +12,7 @@ import sys
 import tempfile
 from abc import ABC, abstractmethod
 from collections.abc import Set
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import lru_cache
 from typing import Callable, Optional, TypeVar, Union
@@ -181,10 +181,6 @@ _CROSS_ENCODER_MODELS = {
     "ModernBertForSequenceClassification": ("modernbert",
                                             "ModernBertForSequenceClassification"),
     # [Auto-converted (see adapters.py)]
-    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),  # noqa: E501
-    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
-    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
-    "LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"),  # noqa: E501
     "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
 }
@@ -462,10 +458,26 @@ class _ModelRegistry:
         return _try_load_model_cls(model_arch, self.models[model_arch])
 
     def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
-        if model_arch not in self.models:
-            return None
+        if model_arch in self.models:
+            return _try_inspect_model_cls(model_arch, self.models[model_arch])
 
-        return _try_inspect_model_cls(model_arch, self.models[model_arch])
+        if model_arch.endswith("ForSequenceClassification"):
+            causal_lm_arch = model_arch.replace("ForSequenceClassification",
+                                                "ForCausalLM")
+            if causal_lm_arch not in self.models:
+                return None
+
+            info = _try_inspect_model_cls(causal_lm_arch,
+                                          self.models[causal_lm_arch])
+            info = _ModelInfo(**dict(
+                asdict(info), **{
+                    "architecture": model_arch,
+                    "supports_cross_encoding": True
+                }))
+            return info
+
+        return None
 
     def _normalize_archs(
         self,
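The asdict(info) merge above is the usual dataclass copy-with-overrides idiom: serialize the inspected CausalLM info, override the architecture name, and flag cross-encoding support. A minimal standalone sketch (fields abbreviated from _ModelInfo):

    from dataclasses import asdict, dataclass

    @dataclass
    class _Info:  # abbreviated stand-in for _ModelInfo
        architecture: str
        supports_cross_encoding: bool = False

    base = _Info("Qwen3ForCausalLM")
    derived = _Info(**dict(asdict(base),
                           architecture="Qwen3ForSequenceClassification",
                           supports_cross_encoding=True))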
@@ -480,6 +492,15 @@ class _ModelRegistry:
         normalized_arch = list(
             filter(lambda model: model in self.models, architectures))
 
+        # try automatic conversion in adapters.py
+        for arch in architectures:
+            if not arch.endswith("ForSequenceClassification"):
+                continue
+
+            causal_lm_arch = arch.replace("ForSequenceClassification",
+                                          "ForCausalLM")
+            if causal_lm_arch in self.models:
+                normalized_arch.append(arch)
+
         # make sure Transformers backend is put at the last as a fallback
         if len(normalized_arch) != len(architectures):
             normalized_arch.append("TransformersForCausalLM")
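Net effect of the _normalize_archs change: an unregistered *ForSequenceClassification arch stays in the candidate list whenever its *ForCausalLM twin is registered, so the vLLM-native path wins over the Transformers fallback. An illustrative trace under assumed registry contents:

    models = {"Qwen3ForCausalLM": object()}              # assumed registry contents
    architectures = ["Qwen3ForSequenceClassification"]

    normalized = [a for a in architectures if a in models]            # []
    for arch in architectures:
        if (arch.endswith("ForSequenceClassification") and
                arch.replace("ForSequenceClassification", "ForCausalLM") in models):
            normalized.append(arch)                # ["Qwen3ForSequenceClassification"]

    if len(normalized) != len(architectures):      # equal here: no fallback appended
        normalized.append("TransformersForCausalLM")
    print(normalized)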