mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 16:22:16 +08:00
[New Model]: Support Qwen3 Embedding & Reranker (#19260)
This commit is contained in:
parent
77f0d465d0
commit
3952731e8f
@ -387,18 +387,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling
|
|||||||
|
|
||||||
Specified using `--task embed`.
|
Specified using `--task embed`.
|
||||||
|
|
||||||
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
|
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
|
||||||
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|
|
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|
|
||||||
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
|
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
|
||||||
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | |
|
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | |
|
||||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
|
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
|
||||||
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | |
|
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | |
|
||||||
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ |
|
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ |
|
||||||
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ |
|
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ |
|
||||||
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ |
|
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ |
|
||||||
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
|
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
|
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
|
||||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
|
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
|
||||||
|
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
|
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
|
||||||
@ -450,12 +451,19 @@ If your model is not in the above list, we will try to automatically convert the
|
|||||||
|
|
||||||
Specified using `--task score`.
|
Specified using `--task score`.
|
||||||
|
|
||||||
| Architecture | Models | Example HF Models |
|
| Architecture | Models | Example HF Models |
|
||||||
|---------------------------------------|-------------------|----------------------------------------------|
|
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|
|
||||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
|
||||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
|
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
|
||||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
|
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
|
||||||
|
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
|
||||||
|
```
|
||||||
[](){ #supported-mm-models }
|
[](){ #supported-mm-models }
|
||||||
|
|
||||||
## List of Multimodal Language Models
|
## List of Multimodal Language Models
|
||||||
|
|||||||
77
examples/offline_inference/qwen3_reranker.py
Normal file
77
examples/offline_inference/qwen3_reranker.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
model_name = "Qwen/Qwen3-Reranker-0.6B"
|
||||||
|
|
||||||
|
# What is the difference between the official original version and one
|
||||||
|
# that has been converted into a sequence classification model?
|
||||||
|
# Qwen3-Reranker is a language model that doing reranker by using the
|
||||||
|
# logits of "no" and "yes" tokens.
|
||||||
|
# It needs to computing 151669 tokens logits, making this method extremely
|
||||||
|
# inefficient, not to mention incompatible with the vllm score API.
|
||||||
|
# A method for converting the original model into a sequence classification
|
||||||
|
# model was proposed. See:https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
|
||||||
|
# Models converted offline using this method can not only be more efficient
|
||||||
|
# and support the vllm score API, but also make the init parameters more
|
||||||
|
# concise, for example.
|
||||||
|
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
|
||||||
|
|
||||||
|
# If you want to load the official original version, the init parameters are
|
||||||
|
# as follows.
|
||||||
|
|
||||||
|
model = LLM(
|
||||||
|
model=model_name,
|
||||||
|
task="score",
|
||||||
|
hf_overrides={
|
||||||
|
"architectures": ["Qwen3ForSequenceClassification"],
|
||||||
|
"classifier_from_token": ["no", "yes"],
|
||||||
|
"is_original_qwen3_reranker": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Why do we need hf_overrides for the official original version:
|
||||||
|
# vllm converts it to Qwen3ForSequenceClassification when loaded for
|
||||||
|
# better performance.
|
||||||
|
# - Firstly, we need using `"architectures": ["Qwen3ForSequenceClassification"],`
|
||||||
|
# to manually route to Qwen3ForSequenceClassification.
|
||||||
|
# - Then, we will extract the vector corresponding to classifier_from_token
|
||||||
|
# from lm_head using `"classifier_from_token": ["no", "yes"]`.
|
||||||
|
# - Third, we will convert these two vectors into one vector. The use of
|
||||||
|
# conversion logic is controlled by `using "is_original_qwen3_reranker": True`.
|
||||||
|
|
||||||
|
# Please use the query_template and document_template to format the query and
|
||||||
|
# document for better reranker results.
|
||||||
|
|
||||||
|
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
|
||||||
|
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||||
|
|
||||||
|
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
|
||||||
|
document_template = "<Document>: {doc}{suffix}"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
instruction = (
|
||||||
|
"Given a web search query, retrieve relevant passages that answer the query"
|
||||||
|
)
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"What is the capital of China?",
|
||||||
|
"Explain gravity",
|
||||||
|
]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
"The capital of China is Beijing.",
|
||||||
|
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
|
||||||
|
]
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
query_template.format(prefix=prefix, instruction=instruction, query=query)
|
||||||
|
for query in queries
|
||||||
|
]
|
||||||
|
documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
|
||||||
|
|
||||||
|
outputs = model.score(queries, documents)
|
||||||
|
|
||||||
|
print([output.outputs.score for output in outputs])
|
||||||
@ -45,6 +45,15 @@ MODELS = [
|
|||||||
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
||||||
architecture="ModernBertModel",
|
architecture="ModernBertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
|
########## Qwen3ForCausalLM
|
||||||
|
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
architecture="Qwen3ForCausalLM",
|
||||||
|
dtype="float32",
|
||||||
|
enable_test=True),
|
||||||
|
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
|
||||||
|
architecture="Qwen3ForCausalLM",
|
||||||
|
dtype="float32",
|
||||||
|
enable_test=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
87
tests/models/language/pooling/test_qwen3_reranker.py
Normal file
87
tests/models/language/pooling/test_qwen3_reranker.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
model_name = "Qwen/Qwen3-Reranker-4B"
|
||||||
|
|
||||||
|
text_1 = "What is the capital of France?"
|
||||||
|
texts_2 = [
|
||||||
|
"The capital of Brazil is Brasilia.",
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_reranker(model_name):
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
model = LLM(model=model_name,
|
||||||
|
task="score",
|
||||||
|
hf_overrides={
|
||||||
|
"architectures": ["Qwen3ForSequenceClassification"],
|
||||||
|
"classifier_from_token": ["no", "yes"],
|
||||||
|
"is_original_qwen3_reranker": True,
|
||||||
|
},
|
||||||
|
dtype="float32")
|
||||||
|
|
||||||
|
text_1 = "What is the capital of France?"
|
||||||
|
texts_2 = [
|
||||||
|
"The capital of Brazil is Brasilia.",
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = model.score(text_1, texts_2)
|
||||||
|
|
||||||
|
return [output.outputs.score for output in outputs]
|
||||||
|
|
||||||
|
|
||||||
|
def hf_reranker(model_name):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
|
||||||
|
|
||||||
|
token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||||
|
token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||||
|
|
||||||
|
max_length = 8192
|
||||||
|
|
||||||
|
def process_inputs(pairs):
|
||||||
|
inputs = tokenizer(pairs,
|
||||||
|
padding=False,
|
||||||
|
truncation='longest_first',
|
||||||
|
return_attention_mask=False,
|
||||||
|
max_length=max_length)
|
||||||
|
for i, ele in enumerate(inputs['input_ids']):
|
||||||
|
inputs['input_ids'][i] = ele
|
||||||
|
inputs = tokenizer.pad(inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_length=max_length)
|
||||||
|
for key in inputs:
|
||||||
|
inputs[key] = inputs[key].to(model.device)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def compute_logits(inputs, **kwargs):
|
||||||
|
batch_scores = model(**inputs).logits[:, -1, :]
|
||||||
|
true_vector = batch_scores[:, token_true_id]
|
||||||
|
false_vector = batch_scores[:, token_false_id]
|
||||||
|
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||||
|
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||||
|
scores = batch_scores[:, 1].exp().tolist()
|
||||||
|
return scores
|
||||||
|
|
||||||
|
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
|
||||||
|
inputs = process_inputs(pairs)
|
||||||
|
scores = compute_logits(inputs)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", [model_name])
|
||||||
|
def test_model(model_name):
|
||||||
|
hf_outputs = hf_reranker(model_name)
|
||||||
|
vllm_outputs = vllm_reranker(model_name)
|
||||||
|
|
||||||
|
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||||
|
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
|
||||||
73
tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
Normal file
73
tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
|
||||||
|
|
||||||
|
text_1 = "What is the capital of France?"
|
||||||
|
texts_2 = [
|
||||||
|
"The capital of Brazil is Brasilia.",
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_reranker(model_name):
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
model = LLM(model=model_name, task="score")
|
||||||
|
outputs = model.score(text_1, texts_2)
|
||||||
|
|
||||||
|
return [output.outputs.score for output in outputs]
|
||||||
|
|
||||||
|
|
||||||
|
def hf_reranker(model_name):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
|
||||||
|
|
||||||
|
token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||||
|
token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||||
|
|
||||||
|
max_length = 8192
|
||||||
|
|
||||||
|
def process_inputs(pairs):
|
||||||
|
inputs = tokenizer(pairs,
|
||||||
|
padding=False,
|
||||||
|
truncation='longest_first',
|
||||||
|
return_attention_mask=False,
|
||||||
|
max_length=max_length)
|
||||||
|
for i, ele in enumerate(inputs['input_ids']):
|
||||||
|
inputs['input_ids'][i] = ele
|
||||||
|
inputs = tokenizer.pad(inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_length=max_length)
|
||||||
|
for key in inputs:
|
||||||
|
inputs[key] = inputs[key].to(model.device)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def compute_logits(inputs, **kwargs):
|
||||||
|
batch_scores = model(**inputs).logits[:, -1, :]
|
||||||
|
true_vector = batch_scores[:, token_true_id]
|
||||||
|
false_vector = batch_scores[:, token_false_id]
|
||||||
|
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||||
|
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||||
|
scores = batch_scores[:, 1].exp().tolist()
|
||||||
|
return scores
|
||||||
|
|
||||||
|
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
|
||||||
|
inputs = process_inputs(pairs)
|
||||||
|
scores = compute_logits(inputs)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", [model_name])
|
||||||
|
def test_model(model_name):
|
||||||
|
hf_outputs = hf_reranker(model_name)
|
||||||
|
vllm_outputs = vllm_reranker(model_name)
|
||||||
|
|
||||||
|
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||||
|
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
|
||||||
@ -238,6 +238,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
|
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
|
||||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||||
|
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
|
||||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
|
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
|
||||||
v0_only=True),
|
v0_only=True),
|
||||||
|
|||||||
@ -38,13 +38,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
|
||||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||||
from .qwen2 import Qwen2Model
|
from .qwen2 import Qwen2Model
|
||||||
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
|
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
|
||||||
@ -319,3 +321,122 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
if self.config.tie_word_embeddings else None),
|
if self.config.tie_word_embeddings else None),
|
||||||
)
|
)
|
||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
|
||||||
|
SupportsCrossEncoding):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.prefix = prefix
|
||||||
|
self.model = Qwen3Model(vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
self.score = RowParallelLinear(config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
quant_config=quant_config,
|
||||||
|
input_is_parallel=False,
|
||||||
|
bias=False,
|
||||||
|
prefix=maybe_prefix(prefix, "score"))
|
||||||
|
|
||||||
|
self._pooler = Pooler.from_config_with_defaults(
|
||||||
|
pooler_config,
|
||||||
|
pooling_type=PoolingType.LAST,
|
||||||
|
normalize=False,
|
||||||
|
softmax=True)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors)
|
||||||
|
|
||||||
|
def pooler(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Optional[PoolerOutput]:
|
||||||
|
hidden_states = self._pooler.extract_states(hidden_states,
|
||||||
|
pooling_metadata)
|
||||||
|
logits, _ = self.score(hidden_states)
|
||||||
|
pooled_data = self._pooler.head(logits, pooling_metadata)
|
||||||
|
pooled_outputs = [
|
||||||
|
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
|
||||||
|
]
|
||||||
|
return PoolerOutput(outputs=pooled_outputs)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
is_original_qwen3_reranker = getattr(self.config,
|
||||||
|
"is_original_qwen3_reranker",
|
||||||
|
False)
|
||||||
|
|
||||||
|
if not is_original_qwen3_reranker:
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
return self.load_weights_from_original_qwen3_reranker(weights)
|
||||||
|
|
||||||
|
def load_weights_from_original_qwen3_reranker(
|
||||||
|
self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
tokens = getattr(self.config, "classifier_from_token", None)
|
||||||
|
assert tokens is not None and len(tokens) == 2, \
|
||||||
|
("Try loading the original Qwen3 Reranker?, see: "
|
||||||
|
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
|
||||||
|
|
||||||
|
self.config.num_labels = 1
|
||||||
|
model_config = self.vllm_config.model_config
|
||||||
|
|
||||||
|
device = self.score.weight.device
|
||||||
|
self.score = RowParallelLinear(self.config.hidden_size,
|
||||||
|
self.config.num_labels,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
input_is_parallel=False,
|
||||||
|
bias=False,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
self.prefix, "score")).to(device)
|
||||||
|
|
||||||
|
if self.config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
self.prefix, "lm_head"))
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
loaded_weights = loader.load_weights(weights)
|
||||||
|
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
model_config.tokenizer,
|
||||||
|
revision=model_config.tokenizer_revision,
|
||||||
|
tokenizer_mode=model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=model_config.trust_remote_code)
|
||||||
|
|
||||||
|
a = tokenizer.convert_tokens_to_ids(tokens[0])
|
||||||
|
b = tokenizer.convert_tokens_to_ids(tokens[1])
|
||||||
|
weight = self.lm_head.weight.data[b].to(
|
||||||
|
device) - self.lm_head.weight.data[a].to(device)
|
||||||
|
self.score.weight.data.copy_(weight)
|
||||||
|
|
||||||
|
del self.lm_head
|
||||||
|
loaded_weights.add("classifier.weight")
|
||||||
|
loaded_weights.discard("lm_head.weight")
|
||||||
|
|||||||
@ -172,6 +172,7 @@ _CROSS_ENCODER_MODELS = {
|
|||||||
"RobertaForSequenceClassification"),
|
"RobertaForSequenceClassification"),
|
||||||
"ModernBertForSequenceClassification": ("modernbert",
|
"ModernBertForSequenceClassification": ("modernbert",
|
||||||
"ModernBertForSequenceClassification"),
|
"ModernBertForSequenceClassification"),
|
||||||
|
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
_MULTIMODAL_MODELS = {
|
_MULTIMODAL_MODELS = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user