[Model][Last/4] Automatic conversion of CrossEncoding model (#19675)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent 1ad69e8375
commit 110df74332
@@ -481,11 +481,19 @@ Specified using `--task score`.

| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|--------------|--------|-------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
!!! note
    Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command.

    ```bash
    vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
    ```
!!! note
    Load the official original `mxbai-rerank-v2` by using the following command.

    ```bash
    vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"],"method": "from_2_way_softmax"}'
    ```
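Once a checkpoint has been converted (or served with the overrides above), it can be queried through vLLM's usual scoring path. A minimal offline sketch, assuming the converted checkpoint from the example script below has been saved to the hypothetical local path `./bge-reranker-v2-gemma-seq-cls`:

```python
from vllm import LLM

# Load the converted sequence-classification checkpoint with the score task.
llm = LLM(model="./bge-reranker-v2-gemma-seq-cls", task="score")

# Score one query against two candidate passages.
outputs = llm.score(
    "What is a panda?",
    [
        "The giant panda is a bear species endemic to China.",
        "Paris is the capital of France.",
    ],
)
print([output.outputs.score for output in outputs])
```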
examples/offline_inference/convert_model_to_seq_cls.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    assert len(tokens) == 2

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` defaults to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="JSON list of tokens to build the classifier from",
    )
    parser.add_argument(
        "--method", type=str, default="no_post_processing", help="Conversion method"
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
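The `from_2_way_softmax` conversion relies on an algebraic identity: for two logits, the softmax over the ("no", "yes") pair equals the sigmoid of their difference, so a single score row built from `W_yes - W_no` reproduces the two-token probability exactly. A minimal numeric sketch of that identity (random tensors with hypothetical sizes, not the real model weights):

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(8)                       # stand-in for a final hidden state
w_no, w_yes = torch.randn(8), torch.randn(8)  # stand-ins for the lm_head rows

# Two-way softmax over the "no"/"yes" token logits...
p_yes_softmax = torch.softmax(
    torch.stack([w_no @ hidden, w_yes @ hidden]), dim=0)[1]

# ...equals the sigmoid of the merged single-label score,
# which is the weight the script copies into score.weight.
p_yes_sigmoid = torch.sigmoid((w_yes - w_no) @ hidden)

assert torch.allclose(p_yes_softmax, p_yes_sigmoid)
```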
@@ -267,7 +267,8 @@ def mteb_test_rerank_models(hf_runner,
                            vllm_runner,
                            model_info: RerankModelInfo,
                            vllm_extra_kwargs=None,
-                           hf_model_callback=None):
+                           hf_model_callback=None,
+                           vllm_mteb_encoder=VllmMtebEncoder):
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
@@ -288,7 +289,7 @@ def mteb_test_rerank_models(hf_runner,
        assert (model_info.architecture in model_config.architectures)
        assert model_config.hf_config.num_labels == 1

-       vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
+       vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
                                          tasks=MTEB_RERANK_TASKS,
                                          languages=MTEB_RERANK_LANGS)
        vllm_dtype = model_config.dtype
tests/models/language/pooling/test_bge_reranker_v2_gemma.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import numpy as np
import pytest
import torch

from tests.conftest import HfRunner

from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
                         mteb_test_rerank_models)

RERANK_MODELS = [
    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
                    architecture="GemmaForSequenceClassification"),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501


class GemmaRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

            sep = "\n"
            prompt_inputs = tokenizer(prompt,
                                      return_tensors=None,
                                      add_special_tokens=False)["input_ids"]
            sep_inputs = tokenizer(sep,
                                   return_tensors=None,
                                   add_special_tokens=False)["input_ids"]
            inputs = []
            for query, passage in pairs:
                query_inputs = tokenizer(
                    f"A: {query}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                passage_inputs = tokenizer(
                    f"B: {passage}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                item = tokenizer.prepare_for_model(
                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
                    sep_inputs + passage_inputs["input_ids"],
                    truncation="only_second",
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False,
                )
                item["input_ids"] = item[
                    "input_ids"] + sep_inputs + prompt_inputs
                item["attention_mask"] = [1] * len(item["input_ids"])
                inputs.append(item)
            return tokenizer.pad(
                inputs,
                padding=True,
                return_tensors="pt",
            )

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            _n_tokens = inputs["input_ids"].shape[1]
            logits = self.model(**inputs, return_dict=True).logits
            _scores = logits[:, -1, self.yes_loc].view(-1).float().sigmoid()
            scores.append(_scores[0].item())
        return torch.Tensor(scores)


class GemmaMtebEncoder(VllmMtebEncoder):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = PROMPT
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

    def predict(
        self,
        sentences: list[tuple[str, str,
                              Optional[str]]],  # query, corpus, prompt
        *args,
        **kwargs,
    ) -> np.ndarray:

        _sentences = []
        for query, corpus, prompt in sentences:
            query = self.query_template.format(query=query)
            corpus = self.document_template.format(doc=corpus, prompt=prompt)
            _sentences.append((query, corpus, prompt))

        return super().predict(_sentences, *args, **kwargs)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
                            monkeypatch) -> None:
    monkeypatch.setenv("VLLM_USE_V1", "0")

    assert model_info.architecture == "GemmaForSequenceClassification"

    vllm_extra_kwargs: dict[str, Any] = {
        "hf_overrides": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        }
    }

    mteb_test_rerank_models(GemmaRerankerHfRunner,
                            vllm_runner,
                            model_info,
                            vllm_extra_kwargs,
                            vllm_mteb_encoder=GemmaMtebEncoder)
@@ -12,11 +12,9 @@ from .mteb_utils import RerankModelInfo, mteb_test_rerank_models

RERANK_MODELS = [
    RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
                    architecture="Qwen2ForSequenceClassification",
-                   dtype="float32",
                    enable_test=True),
    RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
                    architecture="Qwen2ForSequenceClassification",
-                   dtype="float32",
                    enable_test=False)
]
@@ -319,9 +319,14 @@ _EMBEDDING_EXAMPLE_MODELS = {

_CROSS_ENCODER_EXAMPLE_MODELS = {
    # [Text-only]
    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
+   "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
+                                                     v0_only=True,
+                                                     hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
+                                                                   "classifier_from_token": ["Yes"],  # noqa: E501
+                                                                   "method": "no_post_processing"}),  # noqa: E501
    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
@@ -1449,6 +1449,12 @@ class ModelConfig:
    def matryoshka_dimensions(self):
        return getattr(self.hf_config, "matryoshka_dimensions", None)

+   @property
+   def use_pad_token(self) -> bool:
+       # cross_encoder models default to using pad_token.
+       # `llm as reranker` models default to not using pad_token.
+       return getattr(self.hf_config, "use_pad_token", True)
+
    def get_and_verify_max_len(self, max_model_len: int):
        # For pooling models, the tokenizer's `model_max_length` is often a
        # reliable source for the maximum sequence length. However, for
@@ -1205,7 +1205,6 @@ class LLM:
        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

        pooling_params = PoolingParams(use_cross_encoder=True)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)
@@ -1213,9 +1212,14 @@ class LLM:
        parsed_prompts = []

        for q, t in input_pairs:
-           prompt_inputs = tokenizer(text=q,
-                                     text_pair=t,
-                                     **tokenization_kwargs)
+           if self.llm_engine.model_config.use_pad_token:
+               # cross_encoder models default to using pad_token.
+               prompt_inputs = tokenizer(text=q,
+                                         text_pair=t,
+                                         **tokenization_kwargs)
+           else:
+               # `llm as reranker` models default to not using pad_token.
+               prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
@@ -167,12 +167,22 @@ class ServingScores(OpenAIServing):
            executor=self._tokenizer_executor)

        tokenization_kwargs = tokenization_kwargs or {}
-       tokenized_prompts = await asyncio.gather(
-           *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
-             for t1, t2 in input_pairs))
+       use_pad_token = self.model_config.use_pad_token
+
+       if use_pad_token:
+           # cross_encoder models default to using pad_token.
+           tokenized_prompts = await asyncio.gather(
+               *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
+                 for t1, t2 in input_pairs))
+       else:
+           # `llm as reranker` models default to not using pad_token.
+           tokenized_prompts = await asyncio.gather(
+               *(tokenize_async(text=t1 + t2, **tokenization_kwargs)
+                 for t1, t2 in input_pairs))

        for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-           sep_token = tokenizer.sep_token if tokenizer.sep_token else ''
+           sep_token = tokenizer.sep_token if (tokenizer.sep_token
+                                               and use_pad_token) else ''
            request_prompt = f"{t1}{sep_token}{t2}"

            input_ids = prompt_inputs["input_ids"]
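The practical difference between the two branches: with `use_pad_token`, query and passage are encoded as a sentence pair, so BERT-style cross-encoders get their `[SEP]`/token-type structure; without it, the two strings are simply concatenated before tokenization. A rough illustration with a stock cross-encoder tokenizer (the model choice here is only for demonstration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
q, t = "what is vllm?", "vLLM is a high-throughput inference engine."

# cross_encoder path: encoded as a pair -> [CLS] q [SEP] t [SEP]
pair_ids = tokenizer(text=q, text_pair=t)["input_ids"]

# `llm as reranker` path: plain string concatenation, no pair structure
concat_ids = tokenizer(text=q + t)["input_ids"]

print(tokenizer.decode(pair_ids))
print(tokenizer.decode(concat_ids))
```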
@@ -312,6 +312,10 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
        else:
            config.num_labels = len(tokens)

+       # `llm as reranker` defaults to not using pad_token
+       use_pad_token = getattr(config, "use_pad_token", False)
+       config.use_pad_token = use_pad_token
+

def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -356,8 +360,49 @@ def load_weights_using_from_2_way_softmax(
    return loaded_weights


+def load_weights_no_post_processing(model,
+                                    weights: Iterable[tuple[str,
+                                                            torch.Tensor]]):
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        ParallelLMHead)
+    from vllm.model_executor.models.utils import AutoWeightsLoader
+
+    model_config = model.vllm_config.model_config
+    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = cast(list[str], tokens)
+    assert len(tokens) > 0
+
+    device = model.score.weight.device
+
+    if model.config.tie_word_embeddings:
+        model.lm_head = model.model.embed_tokens
+    else:
+        model.lm_head = ParallelLMHead(model.config.vocab_size,
+                                       model.config.hidden_size,
+                                       quant_config=model.quant_config)
+
+    loader = AutoWeightsLoader(model)
+    loaded_weights = loader.load_weights(weights)
+
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+    tokenizer = get_tokenizer(model_config.tokenizer,
+                              revision=model_config.tokenizer_revision,
+                              tokenizer_mode=model_config.tokenizer_mode,
+                              trust_remote_code=model_config.trust_remote_code)
+
+    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+    score_weight = model.lm_head.weight.data[token_ids].to(device)
+    model.score.weight.data.copy_(score_weight)
+
+    del model.lm_head
+    loaded_weights.add("score.weight")
+    loaded_weights.discard("lm_head.weight")
+    return loaded_weights
+
+
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
+   "no_post_processing": load_weights_no_post_processing,
}
@@ -368,6 +413,9 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    #     - Qwen3-Reranker
    #   - Qwen2ForCausalLM
    #     - mxbai-rerank-v2
+   # - no_post_processing:
+   #   - GemmaForCausalLM
+   #     - bge-reranker-v2-gemma

    config = model.vllm_config.model_config.hf_config
    method = getattr(config, "method", None)
@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

+from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
@@ -425,3 +426,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
+
+
+GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)
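This single line is all a decoder-only model needs to gain a scoring head; the adapter wires the `classifier_from_token`/`method` handling around the existing causal-LM implementation. As a sketch, another registered `*ForCausalLM` could be wrapped the same way (the Llama variant below is illustrative only, not part of this change):

```python
# Hypothetical: expose a seq-cls variant of another vLLM causal LM.
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.llama import LlamaForCausalLM

LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)
```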
@@ -179,8 +179,9 @@ _CROSS_ENCODER_MODELS = {
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
+   "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),  # noqa: E501
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
}

_MULTIMODAL_MODELS = {