From bd4c1e6fdbec56594079764bcb74c7e2a81ce525 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sun, 13 Jul 2025 16:09:34 +0900 Subject: [PATCH] Support for LlamaForSequenceClassification (#20807) Signed-off-by: thechaos16 --- tests/models/registry.py | 1 + vllm/model_executor/models/llama.py | 4 ++++ vllm/model_executor/models/registry.py | 3 ++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index c10d375683eea..1207a928c92f3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -330,6 +330,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = { hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 + "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 48ec611df12dd..2434ac9d205da 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -645,3 +646,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): name = name.replace(item, mapping[item]) return name, loaded_weight + 
+ +LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e8530a555d286..b7d4789549aa0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -183,7 +183,8 @@ _CROSS_ENCODER_MODELS = { "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501 "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 - "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501 + "LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501 + "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501 } _MULTIMODAL_MODELS = {