[Model] Re-add the implicit conversion feature for as_seq_cls_model (#21103)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-07-18 15:15:07 +08:00 committed by GitHub
parent ba2dfbb0c2
commit ca4eb82bcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 165 additions and 75 deletions

View File

@ -265,7 +265,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@ -292,7 +291,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
@ -311,7 +309,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
@ -324,20 +321,29 @@ _EMBEDDING_EXAMPLE_MODELS = {
is_available_online=False), # noqa: E501
}
_CROSS_ENCODER_EXAMPLE_MODELS = {
# [Text-only]
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Decoder-only]
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
}
_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
}
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
@ -449,6 +455,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
@ -489,7 +496,7 @@ _TRANSFORMERS_MODELS = {
_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_CROSS_ENCODER_EXAMPLE_MODELS,
**_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
**_TRANSFORMERS_MODELS,
@ -522,3 +529,4 @@ class HfExampleModels:
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)

View File

@ -13,20 +13,21 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS
from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
@create_new_process_for_each_test()
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"""The reason for using create_new_process_for_each_test is to avoid
the WARNING:
"We must use the 'spawn' multiprocessing start method. Overriding
def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
EXAMPLE_MODELS: HfExampleModels):
"""The reason for using create_new_process_for_each_test is to avoid
the WARNING:
"We must use the 'spawn' multiprocessing start method. Overriding
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
The spawn process causes the _initialize_kv_caches_v1 function below to
The spawn process causes the _initialize_kv_caches_v1 function below to
become ineffective.
"""
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
@ -127,3 +128,15 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
load_format="dummy",
hf_overrides=hf_overrides,
)
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
@pytest.mark.parametrize("model_arch",
AUTO_EXAMPLE_MODELS.get_supported_archs())
def test_implicit_converted_models(model_arch: str,
monkeypatch: pytest.MonkeyPatch):
can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS)

View File

@ -138,3 +138,38 @@ def test_quantization(
name_0="transformers",
name_1="vllm",
)
@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
monkeypatch,
) -> None:
import torch
from transformers import AutoModelForSequenceClassification
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
model_impl="transformers") as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)

View File

@ -551,7 +551,7 @@ class ModelConfig:
# For pooling models, self.task is used to indicate the
# user-selected task
if self.task == "score":
if self.registry.is_cross_encoder_model(self.architectures):
if self._is_classify_task(self.architectures):
self.task = "classify"
else:
self.task = "embed"
@ -806,6 +806,12 @@ class ModelConfig:
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
def _is_classify_task(self, architectures: list[str]):
for arch in architectures:
if arch.endswith("ForSequenceClassification"):
return True
return self.registry.is_cross_encoder_model(architectures)
def _get_preferred_pooling_task(
self,
architectures: list[str],
@ -813,14 +819,11 @@ class ModelConfig:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForSequenceClassification", "classify"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
@ -878,11 +881,14 @@ class ModelConfig:
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
if self._is_classify_task(self.architectures):
return {"generate": [], "pooling": ["classify"], "draft": []}
else:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
def _get_supported_runner_types(
self,
@ -925,12 +931,16 @@ class ModelConfig:
f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("ForSequenceClassification", "pooling"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
@ -940,10 +950,6 @@ class ModelConfig:
if arch.endswith(suffix) and pref_runner in supported_runner_types:
return pref_runner
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
@ -1525,7 +1531,7 @@ class ModelConfig:
@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
or getattr(self.hf_config, "is_matryoshka", False))
@property
@ -1539,13 +1545,11 @@ class ModelConfig:
return getattr(self.hf_config, "use_pad_token", True)
def get_and_verify_max_len(self, max_model_len: int):
# For pooling models, the tokenizer's `model_max_length` is often a
# reliable source for the maximum sequence length. However, for
# generative models, this can be incorrect and unduly limit the
# context window (e.g., DeepSeek-R1). Therefore, we only consider
# tokenizer_config for pooling models.
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
tokenizer_config = None
if self.runner_type == "pooling":
if (self.runner_type == "pooling" and getattr(
self.hf_config, "position_embedding_type", "") == "absolute"):
tokenizer_config = try_get_tokenizer_config(
self.tokenizer,
trust_remote_code=self.trust_remote_code,

View File

@ -22,7 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model)
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available
@ -238,9 +239,29 @@ def get_model_architecture(
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
if vllm_not_supported:
# try automatic conversion in adapters.py
for arch in architectures:
if not arch.endswith("ForSequenceClassification"):
continue
assert model_config.task == "classify"
causal_lm_arch = arch.replace("ForSequenceClassification",
"ForCausalLM")
causal_lm_arch_vllm_supported = (causal_lm_arch
in vllm_supported_archs)
if not causal_lm_arch_vllm_supported:
continue
architectures = [causal_lm_arch]
vllm_not_supported = False
break
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)
logger.debug_once("Resolve transformers arch %s", str(architectures))
elif (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures):
@ -248,12 +269,13 @@ def get_model_architecture(
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
logger.debug_once("Automatic conversion using `as_embedding_model`.")
model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
# Cannot automatically run as_seq_cls_model,
# otherwise it will cause a circular reference on is_cross_encoder_model
pass
logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
model_cls = as_seq_cls_model(model_cls)
elif model_config.task == "reward":
logger.debug_once("Automatic conversion using `as_reward_model`.")
model_cls = as_reward_model(model_cls)
return model_cls, arch

View File

@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[[true_id]].to(
score_weight = model.lm_head.weight.data[[true_id]].to(
torch.float32) - model.lm_head.weight.data[[false_id]].to(
torch.float32)
param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
weight_loader(param, score_weight)
del model.lm_head
loaded_weights.add("score.weight")
@ -350,6 +350,8 @@ def load_weights_no_post_processing(model,
torch.Tensor]]):
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader)
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
@ -357,8 +359,6 @@ def load_weights_no_post_processing(model,
tokens = cast(list[int], tokens)
assert len(tokens) > 0
device = model.score.weight.device
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
trust_remote_code=model_config.trust_remote_code)
token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
score_weight = model.lm_head.weight.data[token_ids].to(device)
model.score.weight.data.copy_(score_weight)
score_weight = model.lm_head.weight.data[token_ids]
param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, score_weight)
del model.lm_head
loaded_weights.add("score.weight")

View File

@ -43,7 +43,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@ -426,6 +425,3 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)

View File

@ -49,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
@ -646,6 +645,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])
return name, loaded_weight
LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)

View File

@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
@ -496,6 +495,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

View File

@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
@ -320,6 +319,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)

View File

@ -12,7 +12,7 @@ import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from typing import Callable, Optional, TypeVar, Union
@ -181,10 +181,6 @@ _CROSS_ENCODER_MODELS = {
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
# [Auto-converted (see adapters.py)]
"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
"LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
}
@ -462,10 +458,26 @@ class _ModelRegistry:
return _try_load_model_cls(model_arch, self.models[model_arch])
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
if model_arch not in self.models:
return None
if model_arch in self.models:
return _try_inspect_model_cls(model_arch, self.models[model_arch])
return _try_inspect_model_cls(model_arch, self.models[model_arch])
if model_arch.endswith("ForSequenceClassification"):
causal_lm_arch = model_arch.replace("ForSequenceClassification",
"ForCausalLM")
if causal_lm_arch not in self.models:
return None
info = _try_inspect_model_cls(causal_lm_arch,
self.models[causal_lm_arch])
info = _ModelInfo(**dict(
asdict(info), **{
"architecture": model_arch,
"supports_cross_encoding": True
}))
return info
return None
def _normalize_archs(
self,
@ -480,6 +492,15 @@ class _ModelRegistry:
normalized_arch = list(
filter(lambda model: model in self.models, architectures))
# try automatic conversion in adapters.py
for arch in architectures:
if not arch.endswith("ForSequenceClassification"):
continue
causal_lm_arch = arch.replace("ForSequenceClassification",
"ForCausalLM")
if causal_lm_arch in self.models:
normalized_arch.append(arch)
# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")