Fix pooling adapters for Transformers backend (#27338)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-10-24 04:23:55 +01:00 committed by GitHub
parent 70022ffc00
commit 1f9460c4c1
6 changed files with 97 additions and 74 deletions


@@ -793,25 +793,20 @@ class ModelConfig:
cls += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
task = None
if defaults := try_match_architecture_defaults(self.architectures[0]):
_, (runner, convert) = defaults
# Overwrite with user-specified values
_, (runner, task) = defaults
# User-specified values take precedence
if self.runner != "auto":
runner = self.runner
if self.convert not in {"auto", "none"}:
convert = self.convert
# Fall back to default values if still not set
if runner is None:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend task
if runner == "pooling":
if convert == "embed":
return cls + "EmbeddingModel"
if convert == "classify":
return cls + "ForSequenceClassification"
# Only consider Transformers backend pooling classes if we're wrapping an
# architecture that defaults to pooling. Otherwise, we return the LM class
# and use adapters.
if runner == "pooling" and task in {"embed", "classify"}:
if task == "embed":
cls += "EmbeddingModel"
elif task == "classify":
cls += "ForSequenceClassification"
else:
cls += "ForCausalLM"
return cls
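The resolution above can be read as: a Transformers backend pooling class is chosen only when the wrapped architecture's registered defaults already point at an embedding or classification task; every other combination resolves to the causal LM class and relies on the pooling adapters instead. A standalone sketch of that branch, using an illustrative helper name that is not part of this change:

def resolve_transformers_cls(cls: str, runner: str | None, task: str | None) -> str:
    # Pooling classes are picked only for architectures whose defaults are pooling.
    if runner == "pooling" and task in {"embed", "classify"}:
        if task == "embed":
            return cls + "EmbeddingModel"
        return cls + "ForSequenceClassification"
    # Everything else returns the LM class; pooling is layered on via adapters.
    return cls + "ForCausalLM"

assert resolve_transformers_cls("Transformers", "pooling", "embed") == "TransformersEmbeddingModel"
assert resolve_transformers_cls("Transformers", "pooling", None) == "TransformersForCausalLM"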


@@ -283,7 +283,6 @@ def as_seq_cls_model(cls: _T) -> _T:
Pooler,
)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
@@ -291,13 +290,13 @@ def as_seq_cls_model(cls: _T) -> _T:
_create_pooling_model_cls(cls), SupportsCrossEncoding
):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
config = vllm_config.model_config.hf_config
text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
self.score = ReplicatedLinear(
model_config.hidden_size,
config.num_labels,
text_config.num_labels,
bias=False,
params_dtype=vllm_config.model_config.head_dtype,
quant_config=quant_config,
@@ -322,20 +321,10 @@ def as_seq_cls_model(cls: _T) -> _T:
}
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
return super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
text_config = self.config.get_text_config()
tokens = getattr(text_config, "classifier_from_token", None)
method = getattr(text_config, "method", None)
if tokens is None and method is None:
return super().load_weights(weights)
@@ -392,9 +381,9 @@ def as_reward_model(cls: _T) -> _T:
class SequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
method = getattr(config, "method", None)
tokens = getattr(config, "classifier_from_token", None)
text_config = vllm_config.model_config.hf_config.get_text_config()
method = getattr(text_config, "method", None)
tokens = getattr(text_config, "classifier_from_token", None)
if method is None:
return
@@ -404,13 +393,13 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
if method == "from_2_way_softmax":
assert len(tokens) == 2
config.num_labels = 1
text_config.num_labels = 1
else:
config.num_labels = len(tokens)
text_config.num_labels = len(tokens)
# `llm as reranker` defaults to not using pad_token
use_pad_token = getattr(config, "use_pad_token", False)
config.use_pad_token = use_pad_token
use_pad_token = getattr(text_config, "use_pad_token", False)
text_config.use_pad_token = use_pad_token
def load_weights_using_from_2_way_softmax(
@@ -419,24 +408,31 @@ def load_weights_using_from_2_way_softmax(
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
quant_config = model.vllm_config.quant_config
text_config = model.config.get_text_config()
tokens = getattr(model.config, "classifier_from_token", [])
tokens = getattr(text_config, "classifier_from_token", [])
tokens = cast(list[int], tokens)
assert len(tokens) == 2
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
quant_config = model.vllm_config.quant_config
model.lm_head = ParallelLMHead(
model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
)
if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not
# have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers backend.
embed_tokens = (
model.model.embed_tokens
if hasattr(model.model, "embed_tokens")
else model.model.get_input_embeddings()
)
model.lm_head = model.lm_head.tie_weights(embed_tokens)
loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights)
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion
loaded_weights = type(model).__mro__[1].load_weights(model, weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
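The `type(model).__mro__[1]` call above (and its twin in the next hunk) replaces `AutoWeightsLoader` so that weight loading re-enters the wrapped LM's own `load_weights` rather than the generated `*ForSequenceClassification` subclass, whose `load_weights` would dispatch straight back into this loader and recurse forever. A toy sketch of the pattern, with made-up class names:

class WrappedLM:
    def load_weights(self, weights):
        return {name for name, _ in weights}

def seq_cls_loader(model, weights):
    # model's own class overrides load_weights to call this very function, so
    # index the MRO at [1] to skip it and reach the wrapped LM's loader.
    return type(model).__mro__[1].load_weights(model, weights)

class WrappedLMForSequenceClassification(WrappedLM):
    def load_weights(self, weights):
        return seq_cls_loader(self, weights)

assert WrappedLMForSequenceClassification().load_weights([("lm_head.weight", None)]) == {"lm_head.weight"}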
@@ -466,23 +462,31 @@ def load_weights_using_from_2_way_softmax(
def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", [])
quant_config = model.vllm_config.quant_config
text_config = model.config.get_text_config()
tokens = getattr(text_config, "classifier_from_token", [])
tokens = cast(list[int], tokens)
assert len(tokens) > 0
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
quant_config = model.vllm_config.quant_config
model.lm_head = ParallelLMHead(
model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
)
if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not
# have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers backend.
embed_tokens = (
model.model.embed_tokens
if hasattr(model.model, "embed_tokens")
else model.model.get_input_embeddings()
)
model.lm_head = model.lm_head.tie_weights(embed_tokens)
loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights)
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion
loaded_weights = type(model).__mro__[1].load_weights(model, weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -523,7 +527,7 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
# - GemmaForCausalLM
# - bge-reranker-v2-gemma
config = model.vllm_config.model_config.hf_config
method = getattr(config, "method", None)
text_config = model.vllm_config.model_config.hf_config.get_text_config()
method = getattr(text_config, "method", None)
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
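Throughout this file the change is the same: classification-related fields (`num_labels`, `classifier_from_token`, `method`, `use_pad_token`, `vocab_size`, ...) are now read from `hf_config.get_text_config()` instead of `hf_config`, so the adapters also work when the wrapped model is multimodal and its text fields live on a nested config. A minimal illustration with Hugging Face configs (model classes chosen only as examples, assuming current `get_text_config()` behavior):

from transformers import LlavaConfig, Qwen2Config

# Text-only model: the text config is the top-level config itself.
text_cfg = Qwen2Config()
assert text_cfg.get_text_config() is text_cfg

# Composite multimodal model: fields such as num_labels or vocab_size live on
# the nested text_config, which is what get_text_config() returns.
mm_cfg = LlavaConfig()
assert mm_cfg.get_text_config() is mm_cfg.text_config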


@@ -49,6 +49,7 @@ from vllm.model_executor.models.transformers.utils import (
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
@@ -92,6 +93,27 @@ ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints,
# handling the case where it is already present
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model` (pooling included because of adapters)
"model.lm_head.": "lm_head.",
"model.score.": "classifier.",
"model.classifier.": "classifier.",
}
)
def __init_subclass__(cls, *args, **kwargs):
"""Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
super().__init_subclass__(*args, **kwargs)
hf_to_vllm_mapper = WeightsMapper()
for base in cls.__mro__:
if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
hf_to_vllm_mapper |= base_hf_to_vllm_mapper
cls.hf_to_vllm_mapper = hf_to_vllm_mapper
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__()
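The `__init_subclass__` hook above means each subclass of `Base` only declares the mapping entries it adds; at class-creation time the mappers of every class in the MRO are OR-ed together into one `WeightsMapper`, so the prefix handling that used to live in `LegacyMixin` (next file) can move here without every subclass repeating it. A toy sketch of the same pattern using plain dicts instead of `WeightsMapper`:

class MapperMixin:
    mapping: dict[str, str] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        merged: dict[str, str] = {}
        # Walk the MRO from most specific to least specific; each class
        # contributes the entries it declared.
        for base in cls.__mro__:
            merged |= getattr(base, "mapping", {}) or {}
        cls.mapping = merged

class Root(MapperMixin):
    mapping = {"": "model.", "model.model.": "model."}

class Pooling(Root):
    mapping = {"model.score.": "classifier."}

# Pooling ends up with its own entries plus everything inherited from Root.
assert Pooling.mapping == {
    "model.score.": "classifier.",
    "": "model.",
    "model.model.": "model.",
}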


@@ -34,13 +34,6 @@ class LegacyMixin:
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms


@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
import torch
from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
@@ -82,14 +83,14 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
if hasattr(module, "pooler") and module.pooler is None:
self.model.pooler = None
break
if self.model.pooler is not None:
raise ValueError(
"Sequence classification models with pooling layers are not "
"supported yet in the Transformers backend."
)
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
self.classifier = seq_cls_model.classifier
self.classifier = getattr_iter(seq_cls_model, ["classifier", "score"], None)
if self.classifier is None:
raise ValueError(
"Could not find `classifier` or `score` layer in the "
"`AutoModelForSequenceClassification` instance."
)
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__):
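`getattr_iter` is used above because Transformers sequence-classification models do not agree on the head's attribute name: encoder-style models usually call it `classifier`, while decoder-style ones such as `LlamaForSequenceClassification` call it `score`. A rough stand-in for the imported helper, just to show the lookup order:

def getattr_iter(obj, names, default):
    # Return the first attribute from `names` that exists on `obj`.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

class FakeSeqClsModel:
    score = "linear head"  # decoder-style models expose `score` rather than `classifier`

assert getattr_iter(FakeSeqClsModel(), ["classifier", "score"], None) == "linear head"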


@@ -46,6 +46,14 @@ class WeightsMapper:
orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
"""Combine two `WeightsMapper`s by merging their mappings."""
return WeightsMapper(
orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
)
def _map_name(self, key: str) -> str | None:
for substr, new_key in self.orig_to_new_substr.items():
if substr in key:
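
With `__or__` in place, two mappers can be combined directly, which is what the `__init_subclass__` hook in the Transformers backend base class relies on. A quick usage sketch against the dataclass fields shown above:

from vllm.model_executor.models.utils import WeightsMapper

base = WeightsMapper(orig_to_new_prefix={"": "model.", "model.model.": "model."})
heads = WeightsMapper(orig_to_new_prefix={"model.score.": "classifier."})

merged = base | heads
# Entries from `heads` are added on top of (and, on key clashes, override) `base`.
assert merged.orig_to_new_prefix == {
    "": "model.",
    "model.model.": "model.",
    "model.score.": "classifier.",
}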