[Model][3/N] Automatic conversion of CrossEncoding model (#20168)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent 9e5452ee34
commit 2e26f9156a
@@ -477,12 +477,20 @@ If your model is not in the above list, we will try to automatically convert the

Specified using `--task score`.

| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|--------------|--------|-------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |

| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|--------------|--------|-------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |

!!! note
    Load the official original `mxbai-rerank-v2` by using the following command.

    ```bash
    vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
    ```

!!! note
    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.

@@ -490,6 +498,7 @@ Specified using `--task score`.

    ```bash
    vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
    ```
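Either converted checkpoint can also be scored offline, without starting a server. A minimal sketch, assuming the `hf_overrides` keyword of `vllm.LLM` accepts the same JSON shown in the serve commands above and that scoring outputs expose a `.outputs.score` attribute (these names are not taken from this diff):

```python
from vllm import LLM

# Sketch only: load the original Qwen3 reranker and convert it on the fly,
# mirroring the --hf_overrides JSON used with `vllm serve` above.
llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

query = "What is the capital of France?"
documents = ["Paris is the capital of France.", "The Nile is a river in Africa."]

# score() pairs the query with each document and returns one relevance score per pair.
outputs = llm.score(query, documents)
print([output.outputs.score for output in outputs])
```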

[](){ #supported-mm-models }

## List of Multimodal Language Models

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional

import pytest

@@ -74,6 +75,13 @@ def test_models(
        vllm_extra_kwargs["override_pooler_config"] = \
            PoolerConfig(pooling_type="MEAN", normalize=False)

    max_model_len: Optional[int] = 512
    if model in [
            "sentence-transformers/all-MiniLM-L12-v2",
            "sentence-transformers/stsb-roberta-base-v2"
    ]:
        max_model_len = None

    # The example_prompts has ending "\n", for example:
    # "Write a short story about a robot that dreams for the first time.\n"
    # sentence_transformers will strip the input texts, see:
@@ -87,7 +95,7 @@ def test_models(

    with vllm_runner(model,
                     task="embed",
                     max_model_len=512,
                     max_model_len=max_model_len,
                     **vllm_extra_kwargs) as vllm_model:
        vllm_outputs = vllm_model.embed(example_prompts)

@@ -56,10 +56,16 @@ MODELS = [
                   enable_test=False),
]

V1FlashAttentionImpNotSupported = [
    "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base"
]


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
                           monkeypatch) -> None:
    if model_info.name in V1FlashAttentionImpNotSupported:
        monkeypatch.setenv("VLLM_USE_V1", "0")

    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "GteNewModel":
@@ -71,8 +77,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner,

@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
                                  model_info: EmbedModelInfo,
                                  example_prompts) -> None:
                                  model_info: EmbedModelInfo, example_prompts,
                                  monkeypatch) -> None:
    if model_info.name in V1FlashAttentionImpNotSupported:
        monkeypatch.setenv("VLLM_USE_V1", "0")

    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "GteNewModel":

tests/models/language/pooling/test_mxbai_rerank.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import RerankModelInfo, mteb_test_rerank_models

RERANK_MODELS = [
    RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
                    architecture="Qwen2ForSequenceClassification",
                    dtype="float32",
                    enable_test=True),
    RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
                    architecture="Qwen2ForSequenceClassification",
                    dtype="float32",
                    enable_test=False)
]


class MxbaiRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
        self.no_loc = self.tokenizer.convert_tokens_to_ids("0")

    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

        def process_inputs(pairs):
            inputs = self.tokenizer(pairs,
                                    padding=False,
                                    truncation='longest_first',
                                    return_attention_mask=False)
            for i, ele in enumerate(inputs['input_ids']):
                inputs['input_ids'][i] = ele
            inputs = self.tokenizer.pad(inputs,
                                        padding=True,
                                        return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs

        @torch.no_grad()
        def compute_logits(inputs):
            logits = self.model(**inputs).logits[:, -1, :]
            yes_logits = logits[:, self.yes_loc]
            no_logits = logits[:, self.no_loc]
            logits = yes_logits - no_logits
            scores = logits.float().sigmoid()
            return scores

        scores = []
        for prompt in prompts:
            inputs = process_inputs([prompt])
            score = compute_logits(inputs)
            scores.append(score[0].item())
        return torch.Tensor(scores)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "Qwen2ForSequenceClassification":
        vllm_extra_kwargs["hf_overrides"] = {
            "architectures": ["Qwen2ForSequenceClassification"],
            "classifier_from_token": ["0", "1"],
            "method": "from_2_way_softmax",
        }

    mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
                            vllm_extra_kwargs)

@@ -466,6 +466,9 @@ class ModelConfig:
                "affect the random state of the Python process that "
                "launched vLLM.", self.seed)

        # served_model_name must be set before maybe_model_redirect(self.model)
        self.served_model_name = get_served_model_name(self.model,
                                                       self.served_model_name)
        self.model = maybe_model_redirect(self.model)
        # The tokenizer is consistent with the model by default.
        if self.tokenizer is None:
@@ -609,8 +612,6 @@ class ModelConfig:

        self.original_max_model_len = self.max_model_len
        self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
        self.served_model_name = get_served_model_name(self.model,
                                                       self.served_model_name)
        self.multimodal_config = self._init_multimodal_config()
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
@@ -1420,7 +1421,7 @@ class ModelConfig:

    @property
    def is_cross_encoder(self) -> bool:
        return self.registry.is_cross_encoder_model(self.architectures)
        return self.task == "classify"

    @property
    def use_mla(self) -> bool:
@@ -4762,6 +4763,12 @@ class VllmConfig:
            if cls is not None:
                cls.verify_and_update_config(self)

        if self.model_config.task == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import (
                SequenceClassificationConfig)
            SequenceClassificationConfig.verify_and_update_config(self)

    def __str__(self):
        return (
            f"model={self.model_config.model!r},"

@@ -2,14 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast

import torch
import torch.nn as nn

from vllm.model_executor.models.config import VerifyAndUpdateConfig

from .interfaces_base import VllmModelForPooling, is_pooling_model

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import PoolingType

_T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
    default_softmax: bool,
) -> _T:
    # Lazy import
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
    from vllm.model_executor.pooling_metadata import PoolingMetadata

@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
        return cls

    # Lazy import
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.linear import RowParallelLinear
    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def as_seq_cls_model(cls: _T) -> _T:
            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config

            self.vllm_config = vllm_config
            self.task = vllm_config.model_config.task
            self.pooling_type = (
                vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def as_seq_cls_model(cls: _T) -> _T:
            ]
            return PoolerOutput(outputs=pooled_outputs)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            tokens = getattr(self.config, "classifier_from_token", None)
            method = getattr(self.config, "method", None)

            if tokens is None and method is None:
                return super().load_weights(weights)
            else:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, weights)

    ModelForSequenceClassification.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
        _get_pooling_model_name(cls.__name__, "ForReward")

    return ModelForReward  # type: ignore


class SequenceClassificationConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config
        method = getattr(config, "method", None)
        tokens = getattr(config, "classifier_from_token", None)

        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            assert len(tokens) == 2
            config.num_labels = 1
        else:
            config.num_labels = len(tokens)


def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[int], tokens)
    assert len(tokens) == 2

    device = model.score.weight.device

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    from vllm.transformers_utils.tokenizer import get_tokenizer
    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
    weight = model.lm_head.weight.data[true_id].to(device).to(
        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
            torch.float32)
    model.score.weight.data.copy_(weight)

    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    # Online convert ForCausalLM into ForSequenceClassification model.
    # - from_2_way_softmax:
    #   - Qwen3ForCausalLM
    #     - Qwen3-Reranker
    #   - Qwen2ForCausalLM
    #     - mxbai-rerank-v2

    config = model.vllm_config.model_config.hf_config
    method = getattr(config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    return SEQ_CLS_LOAD_METHODS[method](model, weights)
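Why `from_2_way_softmax` can collapse two vocabulary rows into a single score row: a softmax over the two classifier-token logits equals a sigmoid of their difference, so copying `lm_head[true] - lm_head[false]` into `score.weight` reproduces the original reranker's probability. A small self-contained sketch with made-up shapes and token ids (illustrative only, not vLLM code):

```python
import torch

# Toy sizes; the real loader uses the lm_head rows of the "yes"/"no" (or "1"/"0") tokens.
hidden_size, vocab_size = 8, 32
torch.manual_seed(0)

lm_head = torch.randn(vocab_size, hidden_size)   # stand-in for model.lm_head.weight
hidden = torch.randn(hidden_size)                 # pooled last-token hidden state
false_id, true_id = 3, 7                          # stand-ins for the two classifier tokens

# Original reranker: 2-way softmax over the two token logits, take P("yes").
logits = lm_head @ hidden
p_yes_softmax = torch.softmax(
    torch.stack([logits[false_id], logits[true_id]]), dim=0)[1]

# Converted model: one score row (w_true - w_false), then sigmoid.
score_weight = lm_head[true_id] - lm_head[false_id]
p_yes_sigmoid = torch.sigmoid(score_weight @ hidden)

assert torch.allclose(p_yes_softmax, p_yes_sigmoid)
```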

@@ -167,7 +167,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
        assert tokens is not None and len(tokens) == 2, \
            ("Try loading the original Qwen3 Reranker? See: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
        config.num_labels = 1
        vllm_config.model_config.hf_config.method = "from_2_way_softmax"


class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):

@@ -38,15 +38,14 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -323,114 +322,4 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        return loader.load_weights(weights)


class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
                                     SupportsCrossEncoding):

    def __init__(
        self,
        vllm_config: "VllmConfig",
        prefix: str = "",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config

        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        self.prefix = prefix
        self.model = Qwen3Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.score = RowParallelLinear(config.hidden_size,
                                       config.num_labels,
                                       quant_config=quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(prefix, "score"))

        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        hidden_states = self._pooler.extract_states(hidden_states,
                                                    pooling_metadata)

        if isinstance(hidden_states, list):
            logits = [self.score(state)[0] for state in hidden_states]
        else:
            logits, _ = self.score(hidden_states)

        pooled_data = self._pooler.head(logits, pooling_metadata)
        pooled_outputs = [
            self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
        ]
        return PoolerOutput(outputs=pooled_outputs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        is_original_qwen3_reranker = getattr(self.config,
                                             "is_original_qwen3_reranker",
                                             False)

        if not is_original_qwen3_reranker:
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)

        return self.load_weights_from_original_qwen3_reranker(weights)

    def load_weights_from_original_qwen3_reranker(
            self, weights: Iterable[tuple[str, torch.Tensor]]):

        model_config = self.vllm_config.model_config
        tokens = getattr(self.config, "classifier_from_token", None)
        device = self.score.weight.device

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          quant_config=self.quant_config,
                                          prefix=maybe_prefix(
                                              self.prefix, "lm_head"))

        loader = AutoWeightsLoader(self)
        loaded_weights = loader.load_weights(weights)

        from vllm.transformers_utils.tokenizer import get_tokenizer
        tokenizer = get_tokenizer(
            model_config.tokenizer,
            revision=model_config.tokenizer_revision,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)

        a = tokenizer.convert_tokens_to_ids(tokens[0])
        b = tokenizer.convert_tokens_to_ids(tokens[1])
        weight = self.lm_head.weight.data[b].to(
            device) - self.lm_head.weight.data[a].to(device)
        self.score.weight.data.copy_(weight)

        del self.lm_head
        loaded_weights.add("score.weight")
        loaded_weights.discard("lm_head.weight")
        return loaded_weights


Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)
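With the hand-written class removed, the generic adapter provides the score head, LAST pooling, and the online-conversion `load_weights` path. In principle any decoder-only `*ForCausalLM` can be wrapped the same way; a hedged sketch (the Qwen2 wiring shown here is illustrative only, and the actual registration in this commit may live in the model registry rather than in `qwen2.py`):

```python
# Illustrative only: exposing another causal-LM class as a
# sequence-classification model via the same adapter.
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
```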