[Model] Re-add the implicit conversion feature for as_seq_cls_model (#21103)
Signed-off-by: wang.yuqi <noooop@126.com>
parent ba2dfbb0c2
commit ca4eb82bcb
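With the implicit conversion restored, a `*ForSequenceClassification` checkpoint whose `*ForCausalLM` base architecture is registered in vLLM loads without a hand-written per-model alias: the loader resolves the causal-LM class and wraps it with `as_seq_cls_model` on the fly. A minimal usage sketch, assuming the public `LLM(..., task="classify")` and `LLM.classify()` APIs behave as exercised by the tests in this diff; the checkpoint name is taken from the test registry below.

# Sketch only, not part of this commit: exercise the implicit
# Qwen2ForSequenceClassification -> Qwen2ForCausalLM conversion path.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
outputs = llm.classify(["vLLM is a high-throughput inference engine."])
print(outputs[0].outputs.probs)  # per-class probabilities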
@@ -265,7 +265,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
    "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
    "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),  # noqa: E501
    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -292,7 +291,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
    # [Text-only]
    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True),  # noqa: E501
    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
    "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
    "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                trust_remote_code=True),
@@ -311,7 +309,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
    "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
    "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
    "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True),  # noqa: E501
    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True),  # noqa: E501
    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True),  # noqa: E501
@@ -324,20 +321,29 @@ _EMBEDDING_EXAMPLE_MODELS = {
                                       is_available_online=False),  # noqa: E501
}

_CROSS_ENCODER_EXAMPLE_MODELS = {
    # [Text-only]
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501

    # [Cross-encoder]
    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
                                                      v0_only=True,
                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
                                                                    "method": "no_post_processing"}),  # noqa: E501
    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"),  # noqa: E501
    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
}

_AUTOMATIC_CONVERTED_MODELS = {
    # Use as_seq_cls_model for automatic conversion
    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
                                                      v0_only=True,
                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
                                                                    "classifier_from_token": ["Yes"],  # noqa: E501
                                                                    "method": "no_post_processing"}),  # noqa: E501
    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"),  # noqa: E501
    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
@@ -449,6 +455,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"),  # noqa: E501
}


_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
    "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
                                  speculative_model="abhigoyal/vllm-eagle-llama-68m-random"),  # noqa: E501
@@ -489,7 +496,7 @@ _TRANSFORMERS_MODELS = {
_EXAMPLE_MODELS = {
    **_TEXT_GENERATION_EXAMPLE_MODELS,
    **_EMBEDDING_EXAMPLE_MODELS,
    **_CROSS_ENCODER_EXAMPLE_MODELS,
    **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
    **_MULTIMODAL_EXAMPLE_MODELS,
    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
    **_TRANSFORMERS_MODELS,
@@ -522,3 +529,4 @@ class HfExampleModels:


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)

@@ -13,20 +13,21 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ..utils import create_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS
from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
@create_new_process_for_each_test()
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
    """The reason for using create_new_process_for_each_test is to avoid
    the WARNING:
        "We must use the 'spawn' multiprocessing start method. Overriding
def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
                   EXAMPLE_MODELS: HfExampleModels):
    """The reason for using create_new_process_for_each_test is to avoid
    the WARNING:
        "We must use the 'spawn' multiprocessing start method. Overriding
        VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
    The spawn process causes the _initialize_kv_caches_v1 function below to
    The spawn process causes the _initialize_kv_caches_v1 function below to
    become ineffective.
    """
    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)

    model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

@@ -127,3 +128,15 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
        load_format="dummy",
        hf_overrides=hf_overrides,
    )


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
    can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)


@pytest.mark.parametrize("model_arch",
                         AUTO_EXAMPLE_MODELS.get_supported_archs())
def test_implicit_converted_models(model_arch: str,
                                   monkeypatch: pytest.MonkeyPatch):
    can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS)

@@ -138,3 +138,38 @@ def test_quantization(
        name_0="transformers",
        name_1="vllm",
    )


@pytest.mark.parametrize(
    "model",
    ["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    monkeypatch,
) -> None:
    import torch
    from transformers import AutoModelForSequenceClassification

    with vllm_runner(model,
                     max_model_len=512,
                     dtype=dtype,
                     model_impl="transformers") as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForSequenceClassification) as hf_model:
        hf_outputs = hf_model.classify(example_prompts)

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output,
                              1e-3 if dtype == "float" else 1e-2)

@@ -551,7 +551,7 @@ class ModelConfig:
        # For pooling models, self.task is used to indicate the
        # user-selected task
        if self.task == "score":
            if self.registry.is_cross_encoder_model(self.architectures):
            if self._is_classify_task(self.architectures):
                self.task = "classify"
            else:
                self.task = "embed"
@@ -806,6 +806,12 @@ class ModelConfig:
                f"one of {get_args(TokenizerMode)}.")
        self.tokenizer_mode = tokenizer_mode

    def _is_classify_task(self, architectures: list[str]):
        for arch in architectures:
            if arch.endswith("ForSequenceClassification"):
                return True
        return self.registry.is_cross_encoder_model(architectures)

    def _get_preferred_pooling_task(
        self,
        architectures: list[str],
@@ -813,14 +819,11 @@ class ModelConfig:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):
            return "classify"
        if self.registry.is_transcription_model(architectures):
            return "transcription"

        suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
            # Other models follow this pattern
            ("ForSequenceClassification", "classify"),
            ("EmbeddingModel", "embed"),
            ("RewardModel", "reward"),
        ]
@@ -878,11 +881,14 @@ class ModelConfig:
        self,
        task_option: TaskOption,
    ) -> dict[RunnerType, list[_ResolvedTask]]:
        return {
            "generate": self._get_supported_generation_tasks(task_option),
            "pooling": self._get_supported_pooling_tasks(task_option),
            "draft": ["draft"]
        }
        if self._is_classify_task(self.architectures):
            return {"generate": [], "pooling": ["classify"], "draft": []}
        else:
            return {
                "generate": self._get_supported_generation_tasks(task_option),
                "pooling": self._get_supported_pooling_tasks(task_option),
                "draft": ["draft"]
            }

    def _get_supported_runner_types(
        self,
@@ -925,12 +931,16 @@ class ModelConfig:
                f"Available tasks for runner={task_runner!r}: "
                f"{supported_tasks[task_runner]}")

        if "classify" in supported_tasks.get("pooling", []):
            # When multiple pooling tasks are present, default to
            # pooling (eg cross-encoder) for non-standard architectures.
            return "pooling"

        suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
            ("ForCausalLM", "generate"),
            ("ForConditionalGeneration", "generate"),
            ("ChatModel", "generate"),
            ("LMHeadModel", "generate"),
            ("ForSequenceClassification", "pooling"),
            ("EmbeddingModel", "pooling"),
            ("RewardModel", "pooling"),
        ]
@@ -940,10 +950,6 @@ class ModelConfig:
            if arch.endswith(suffix) and pref_runner in supported_runner_types:
                return pref_runner

        if "classify" in supported_tasks.get("pooling", []):
            # When multiple pooling tasks are present, default to
            # pooling (eg cross-encoder) for non-standard architectures.
            return "pooling"
        if "generate" in supported_runner_types:
            return "generate"
        if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ class ModelConfig:

    @property
    def is_matryoshka(self) -> bool:
        return (hasattr(self.hf_config, "matryoshka_dimensions")
        return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
                or getattr(self.hf_config, "is_matryoshka", False))

    @property
@@ -1539,13 +1545,11 @@ class ModelConfig:
        return getattr(self.hf_config, "use_pad_token", True)

    def get_and_verify_max_len(self, max_model_len: int):
        # For pooling models, the tokenizer's `model_max_length` is often a
        # reliable source for the maximum sequence length. However, for
        # generative models, this can be incorrect and unduly limit the
        # context window (e.g., DeepSeek-R1). Therefore, we only consider
        # tokenizer_config for pooling models.
        # Consider max_model_len in tokenizer_config only when
        # pooling models use absolute position_embedding.
        tokenizer_config = None
        if self.runner_type == "pooling":
        if (self.runner_type == "pooling" and getattr(
                self.hf_config, "position_embedding_type", "") == "absolute"):
            tokenizer_config = try_get_tokenizer_config(
                self.tokenizer,
                trust_remote_code=self.trust_remote_code,

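For readability, here is a minimal standalone sketch of the classify-task detection that the ModelConfig hunks above introduce. This is not the committed code; the real method also consults the registry's cross-encoder check, which is replaced below by a hypothetical callback.

# Sketch only: mirrors the behaviour of ModelConfig._is_classify_task.
def is_classify_task(architectures, is_cross_encoder=lambda archs: False):
    # Any *ForSequenceClassification architecture counts as a classify task,
    # even when only its *ForCausalLM base is registered in vLLM.
    if any(arch.endswith("ForSequenceClassification") for arch in architectures):
        return True
    return is_cross_encoder(architectures)

assert is_classify_task(["Qwen3ForSequenceClassification"]) is True
assert is_classify_task(["Qwen3ForCausalLM"]) is False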
@@ -22,7 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_embedding_model,
                                                 as_reward_model)
                                                 as_reward_model,
                                                 as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available

@@ -238,9 +239,29 @@ def get_model_architecture(
    vllm_supported_archs = ModelRegistry.get_supported_archs()
    vllm_not_supported = not any(arch in vllm_supported_archs
                                 for arch in architectures)

    if vllm_not_supported:
        # try automatic conversion in adapters.py
        for arch in architectures:
            if not arch.endswith("ForSequenceClassification"):
                continue

            assert model_config.task == "classify"
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            causal_lm_arch_vllm_supported = (causal_lm_arch
                                             in vllm_supported_archs)
            if not causal_lm_arch_vllm_supported:
                continue

            architectures = [causal_lm_arch]
            vllm_not_supported = False
            break

    if (model_config.model_impl == ModelImpl.TRANSFORMERS or
            model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
        architectures = resolve_transformers_arch(model_config, architectures)
        logger.debug_once("Resolve transformers arch %s", str(architectures))
    elif (model_config.quantization is not None
          and model_config.quantization not in mixtral_supported
          and "MixtralForCausalLM" in architectures):
@@ -248,12 +269,13 @@ def get_model_architecture(

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        logger.debug_once("Automatic conversion using `as_embedding_model`.")
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        # Cannot automatically run as_seq_cls_model,
        # otherwise it will cause a circular reference on is_cross_encoder_model
        pass
        logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
        model_cls = as_seq_cls_model(model_cls)
    elif model_config.task == "reward":
        logger.debug_once("Automatic conversion using `as_reward_model`.")
        model_cls = as_reward_model(model_cls)

    return model_cls, arch

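The core of the implicit conversion is the architecture rewrite in `get_model_architecture` above: an unsupported `FooForSequenceClassification` is resolved through its `FooForCausalLM` base and then wrapped with `as_seq_cls_model`. A minimal sketch of that rewrite in isolation; the supported-architecture set and return shape below are illustrative only.

# Sketch only: the ForSequenceClassification -> ForCausalLM fallback,
# decoupled from ModelRegistry and ModelConfig.
def resolve_with_seq_cls_fallback(architectures, supported_archs):
    for arch in architectures:
        if arch in supported_archs:
            return arch, False  # natively supported, no conversion needed
    for arch in architectures:
        if arch.endswith("ForSequenceClassification"):
            base = arch.replace("ForSequenceClassification", "ForCausalLM")
            if base in supported_archs:
                return base, True  # load the causal-LM class, then wrap it
    raise ValueError(f"Unsupported architectures: {architectures}")

arch, needs_wrap = resolve_with_seq_cls_fallback(
    ["Qwen3ForSequenceClassification"], {"Qwen3ForCausalLM"})
assert arch == "Qwen3ForCausalLM" and needs_wrap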
@@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
    weight = model.lm_head.weight.data[[true_id]].to(
    score_weight = model.lm_head.weight.data[[true_id]].to(
        torch.float32) - model.lm_head.weight.data[[false_id]].to(
            torch.float32)

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, weight)
    weight_loader(param, score_weight)

    del model.lm_head
    loaded_weights.add("score.weight")
@@ -350,6 +350,8 @@ def load_weights_no_post_processing(model,
                                    torch.Tensor]]):
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
@@ -357,8 +359,6 @@ def load_weights_no_post_processing(model,
    tokens = cast(list[int], tokens)
    assert len(tokens) > 0

    device = model.score.weight.device

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
@@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
                              trust_remote_code=model_config.trust_remote_code)

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = model.lm_head.weight.data[token_ids].to(device)
    model.score.weight.data.copy_(score_weight)
    score_weight = model.lm_head.weight.data[token_ids]

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del model.lm_head
    loaded_weights.add("score.weight")

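The `from_2_way_softmax` loader above builds the single-row score head from the LM-head rows of the "false" and "true" tokens. The difference of the two rows is sufficient because, for a two-token softmax, p(true) = softmax([w_f·h, w_t·h])[1] = sigmoid((w_t - w_f)·h), so one weight row reproduces the same probability through a sigmoid. A small standalone sketch of that equivalence, not part of the commit:

# Sketch only: a 2-way softmax over {false, true} logits equals a sigmoid
# over a single logit built from the difference of the two LM-head rows.
import torch

hidden = torch.randn(4, 16)                          # batch of final hidden states
w_false, w_true = torch.randn(16), torch.randn(16)   # LM-head rows of the two tokens

p_softmax = torch.softmax(
    torch.stack([hidden @ w_false, hidden @ w_true], dim=-1), dim=-1)[..., 1]
p_sigmoid = torch.sigmoid(hidden @ (w_true - w_false))  # what the score head computes

assert torch.allclose(p_softmax, p_sigmoid, atol=1e-6)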
@@ -43,7 +43,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
@@ -426,6 +425,3 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)

@@ -49,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
@@ -646,6 +645,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                name = name.replace(item, mapping[item])

        return name, loaded_weight


LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)

@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
@@ -496,6 +495,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

@@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
@@ -320,6 +319,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)

@@ -12,7 +12,7 @@ import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from typing import Callable, Optional, TypeVar, Union

@@ -181,10 +181,6 @@ _CROSS_ENCODER_MODELS = {
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),  # noqa: E501
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
    "LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"),  # noqa: E501
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
}

@@ -462,10 +458,26 @@ class _ModelRegistry:
        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        return _try_inspect_model_cls(model_arch, self.models[model_arch])
        if model_arch.endswith("ForSequenceClassification"):
            causal_lm_arch = model_arch.replace("ForSequenceClassification",
                                                "ForCausalLM")
            if causal_lm_arch not in self.models:
                return None

            info = _try_inspect_model_cls(causal_lm_arch,
                                          self.models[causal_lm_arch])

            info = _ModelInfo(**dict(
                asdict(info), **{
                    "architecture": model_arch,
                    "supports_cross_encoding": True
                }))
            return info

        return None

    def _normalize_archs(
        self,
@@ -480,6 +492,15 @@ class _ModelRegistry:
        normalized_arch = list(
            filter(lambda model: model in self.models, architectures))

        # try automatic conversion in adapters.py
        for arch in architectures:
            if not arch.endswith("ForSequenceClassification"):
                continue
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            if causal_lm_arch in self.models:
                normalized_arch.append(arch)

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")

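With the registry fallback above, the "Yes"-token reranker recipe from the test registry (the BAAI/bge-reranker-v2-gemma entry) also flows through the implicit conversion: overriding the architecture to `GemmaForSequenceClassification` lets vLLM resolve `GemmaForCausalLM`, wrap it with `as_seq_cls_model`, and build the score head from the "Yes" token. A hedged usage sketch, assuming `LLM` forwards `hf_overrides` and exposes `score()` as in current vLLM releases; `task="score"` is resolved to the classify task by the config change earlier in this diff.

# Sketch only, mirroring the hf_overrides used in the test registry above.
from vllm import LLM

llm = LLM(
    model="BAAI/bge-reranker-v2-gemma",
    task="score",
    hf_overrides={
        "architectures": ["GemmaForSequenceClassification"],
        "classifier_from_token": ["Yes"],
        "method": "no_post_processing",
    },
)
scores = llm.score("what is panda?",
                   ["The giant panda is a bear species endemic to China."])
print(scores[0].outputs.score)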