Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 01:05:28 +08:00)
Fix pooling adapters for Transformers backend (#27338)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 70022ffc00
commit 1f9460c4c1
@@ -793,25 +793,20 @@ class ModelConfig:
         cls += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
-        convert = None
+        task = None
         if defaults := try_match_architecture_defaults(self.architectures[0]):
-            _, (runner, convert) = defaults
-        # Overwrite with user-specified values
+            _, (runner, task) = defaults
+        # User specified value take precedence
         if self.runner != "auto":
             runner = self.runner
-        if self.convert not in {"auto", "none"}:
-            convert = self.convert
-        # Fall back to default values if still not set
-        if runner is None:
-            runner = "generate"
-        if convert in {None, "none"}:
-            convert = "embed"
-        # Resolve Transformers backend task
-        if runner == "pooling":
-            if convert == "embed":
-                return cls + "EmbeddingModel"
-            if convert == "classify":
-                return cls + "ForSequenceClassification"
+        # Only consider Transformers backend pooling classes if we're wrapping an
+        # architecture that defaults to pooling. Otherwise, we return the LM class
+        # and use adapters.
+        if runner == "pooling" and task in {"embed", "classify"}:
+            if task == "embed":
+                cls += "EmbeddingModel"
+            elif task == "classify":
+                cls += "ForSequenceClassification"
         else:
             cls += "ForCausalLM"
         return cls
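To make the new branching concrete, here is a minimal standalone sketch of the class-name resolution it implements; the function name and signature below are illustrative only, not part of vLLM's API:

    def resolve_transformers_backend_cls(runner, task, moe=False):
        # Mirrors the branch added above: pooling classes are only used when the
        # wrapped architecture defaults to pooling; otherwise the LM class is
        # returned and the pooling adapters are applied on top of it.
        cls = "Transformers" + ("MoE" if moe else "")
        if runner == "pooling" and task in {"embed", "classify"}:
            if task == "embed":
                cls += "EmbeddingModel"
            elif task == "classify":
                cls += "ForSequenceClassification"
        else:
            cls += "ForCausalLM"
        return cls

    # resolve_transformers_backend_cls("pooling", "embed") -> "TransformersEmbeddingModel"
    # resolve_transformers_backend_cls("generate", None)   -> "TransformersForCausalLM"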
@@ -283,7 +283,6 @@ def as_seq_cls_model(cls: _T) -> _T:
         Pooler,
     )
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
-    from vllm.sequence import IntermediateTensors

     from .utils import maybe_prefix

@@ -291,13 +290,13 @@ def as_seq_cls_model(cls: _T) -> _T:
         _create_pooling_model_cls(cls), SupportsCrossEncoding
     ):
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
-            config = vllm_config.model_config.hf_config
+            text_config = vllm_config.model_config.hf_config.get_text_config()
             model_config = vllm_config.model_config
             quant_config = vllm_config.quant_config

             self.score = ReplicatedLinear(
                 model_config.hidden_size,
-                config.num_labels,
+                text_config.num_labels,
                 bias=False,
                 params_dtype=vllm_config.model_config.head_dtype,
                 quant_config=quant_config,
@@ -322,20 +321,10 @@ def as_seq_cls_model(cls: _T) -> _T:
                 }
             )

-        def forward(
-            self,
-            input_ids: torch.Tensor,
-            positions: torch.Tensor,
-            intermediate_tensors: IntermediateTensors | None = None,
-            inputs_embeds: torch.Tensor | None = None,
-        ) -> torch.Tensor:
-            return super().forward(
-                input_ids, positions, intermediate_tensors, inputs_embeds
-            )
-
         def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-            tokens = getattr(self.config, "classifier_from_token", None)
-            method = getattr(self.config, "method", None)
+            text_config = self.config.get_text_config()
+            tokens = getattr(text_config, "classifier_from_token", None)
+            method = getattr(text_config, "method", None)

             if tokens is None and method is None:
                 return super().load_weights(weights)
@@ -392,9 +381,9 @@ def as_reward_model(cls: _T) -> _T:
 class SequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
-        method = getattr(config, "method", None)
-        tokens = getattr(config, "classifier_from_token", None)
+        text_config = vllm_config.model_config.hf_config.get_text_config()
+        method = getattr(text_config, "method", None)
+        tokens = getattr(text_config, "classifier_from_token", None)

         if method is None:
             return
@@ -404,13 +393,13 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):

         if method == "from_2_way_softmax":
             assert len(tokens) == 2
-            config.num_labels = 1
+            text_config.num_labels = 1
         else:
-            config.num_labels = len(tokens)
+            text_config.num_labels = len(tokens)

         # `llm as reranker` defaults to not using pad_token
-        use_pad_token = getattr(config, "use_pad_token", False)
-        config.use_pad_token = use_pad_token
+        use_pad_token = getattr(text_config, "use_pad_token", False)
+        text_config.use_pad_token = use_pad_token


 def load_weights_using_from_2_way_softmax(
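The `from_2_way_softmax` case can set `num_labels = 1` because a softmax over two classifier tokens reduces to a sigmoid of their logit difference, so a single head row built from the difference of the two lm_head rows reproduces the two-way score. A self-contained check of that identity (all tensors here are random placeholders for illustration):

    import torch

    w_true, w_false = torch.randn(8), torch.randn(8)  # lm_head rows of the two classifier tokens
    h = torch.randn(8)                                # final hidden state
    p_two_way = torch.softmax(torch.stack([h @ w_false, h @ w_true]), dim=0)[1]
    p_one_logit = torch.sigmoid(h @ (w_true - w_false))
    assert torch.allclose(p_two_way, p_one_logit, atol=1e-6)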
@@ -419,24 +408,31 @@ def load_weights_using_from_2_way_softmax(
     # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-    from vllm.model_executor.models.utils import AutoWeightsLoader

     model_config = model.vllm_config.model_config
+    quant_config = model.vllm_config.quant_config
+    text_config = model.config.get_text_config()

-    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = getattr(text_config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2

-    if model.config.tie_word_embeddings:
-        model.lm_head = model.model.embed_tokens
-    else:
-        quant_config = model.vllm_config.quant_config
-        model.lm_head = ParallelLMHead(
-            model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
+    model.lm_head = ParallelLMHead(
+        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
+    )
+    if text_config.tie_word_embeddings:
+        # embed_tokens is the assumed name for input embeddings. If the model does not
+        # have this attribute, we fallback to get_input_embeddings(), which is used by
+        # the Transformers backend.
+        embed_tokens = (
+            model.model.embed_tokens
+            if hasattr(model.model, "embed_tokens")
+            else model.model.get_input_embeddings()
         )
+        model.lm_head = model.lm_head.tie_weights(embed_tokens)

-    loader = AutoWeightsLoader(model)
-    loaded_weights = loader.load_weights(weights)
+    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+    loaded_weights = type(model).__mro__[1].load_weights(model, weights)

     from vllm.transformers_utils.tokenizer import get_tokenizer

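The `type(model).__mro__[1].load_weights(model, weights)` call replaces the previous `AutoWeightsLoader` path. A toy example (not vLLM code; the class names are made up) of why the adapter class has to be skipped via the MRO:

    class BaseModel:
        def load_weights(self, weights):
            return f"base loaded {len(list(weights))} tensors"

    class SeqClsAdapter(BaseModel):
        def load_weights(self, weights):
            # The adapter's load_weights delegates to a loader helper ...
            return seq_cls_loader(self, weights)

    def seq_cls_loader(model, weights):
        # ... which must reach the *base* implementation. Calling
        # model.load_weights(weights) here would call the adapter again and
        # recurse forever, so __mro__[0] (the adapter class) is skipped.
        return type(model).__mro__[1].load_weights(model, weights)

    print(SeqClsAdapter().load_weights([("w", None)]))  # base loaded 1 tensors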
@@ -466,23 +462,31 @@ def load_weights_using_from_2_way_softmax(
 def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-    from vllm.model_executor.models.utils import AutoWeightsLoader

     model_config = model.vllm_config.model_config
-    tokens = getattr(model.config, "classifier_from_token", [])
+    quant_config = model.vllm_config.quant_config
+    text_config = model.config.get_text_config()
+
+    tokens = getattr(text_config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0

-    if model.config.tie_word_embeddings:
-        model.lm_head = model.model.embed_tokens
-    else:
-        quant_config = model.vllm_config.quant_config
-        model.lm_head = ParallelLMHead(
-            model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
+    model.lm_head = ParallelLMHead(
+        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
+    )
+    if text_config.tie_word_embeddings:
+        # embed_tokens is the assumed name for input embeddings. If the model does not
+        # have this attribute, we fallback to get_input_embeddings(), which is used by
+        # the Transformers backend.
+        embed_tokens = (
+            model.model.embed_tokens
+            if hasattr(model.model, "embed_tokens")
+            else model.model.get_input_embeddings()
        )
+        model.lm_head = model.lm_head.tie_weights(embed_tokens)

-    loader = AutoWeightsLoader(model)
-    loaded_weights = loader.load_weights(weights)
+    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+    loaded_weights = type(model).__mro__[1].load_weights(model, weights)

     from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -523,7 +527,7 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
     # - GemmaForCausalLM
     # - bge-reranker-v2-gemma

-    config = model.vllm_config.model_config.hf_config
-    method = getattr(config, "method", None)
+    text_config = model.vllm_config.model_config.hf_config.get_text_config()
+    method = getattr(text_config, "method", None)
     assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
     return SEQ_CLS_LOAD_METHODS[method](model, weights)
@@ -49,6 +49,7 @@ from vllm.model_executor.models.transformers.utils import (
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -92,6 +93,27 @@ ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # Add `model.` prefix for base model checkpoints,
+            # handling the case where it is already present
+            "": "model.",
+            "model.model.": "model.",
+            # Heads will be adjacent to `model` (pooling included because of adapters)
+            "model.lm_head.": "lm_head.",
+            "model.score.": "classifier.",
+            "model.classifier.": "classifier.",
+        }
+    )
+
+    def __init_subclass__(cls, *args, **kwargs):
+        """Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
+        super().__init_subclass__(*args, **kwargs)
+        hf_to_vllm_mapper = WeightsMapper()
+        for base in cls.__mro__:
+            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
+                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
+        cls.hf_to_vllm_mapper = hf_to_vllm_mapper

     def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
         super().__init__()
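With the `__init_subclass__` hook above, every subclass of `Base` ends up with a `hf_to_vllm_mapper` that merges the mappers declared anywhere along its MRO. A dict-based sketch of the same pattern (class names and mappings here are invented for illustration):

    class Mapped:
        mapping = {"": "model."}

        def __init_subclass__(cls, **kwargs):
            # Merge the `mapping` dicts found anywhere in the subclass's MRO.
            super().__init_subclass__(**kwargs)
            merged = {}
            for base in cls.__mro__:
                merged |= getattr(base, "mapping", {})
            cls.mapping = merged

    class PoolingModel(Mapped):
        mapping = {"model.score.": "classifier."}

    print(PoolingModel.mapping)
    # {'model.score.': 'classifier.', '': 'model.'}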
@@ -34,13 +34,6 @@ class LegacyMixin:
             # Handle BERT-like models
             "roberta": "model",
             "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Classifier/scoring heads will be adjacent to `model`
-            "model.score": "classifier",
-            "model.classifier": "classifier",
         },
         orig_to_new_suffix={
             # Replace legacy suffixes used for norms
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 import torch
 from transformers import AutoModelForSequenceClassification

+from vllm.config.utils import getattr_iter
 from vllm.model_executor.layers.pooler import (
     ClassifierPooler,
     CLSPool,
@@ -82,14 +83,14 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
             if hasattr(module, "pooler") and module.pooler is None:
                 self.model.pooler = None
                 break
-        if self.model.pooler is not None:
-            raise ValueError(
-                "Sequence classification models with pooling layers are not "
-                "supported yet in the Transformers backend."
-            )

         # Unlike `lm_head`, `classifier` is not always `nn.Linear`.
-        self.classifier = seq_cls_model.classifier
+        self.classifier = getattr_iter(seq_cls_model, ["classifier", "score"], None)
+        if self.classifier is None:
+            raise ValueError(
+                "Could not find `classifier` or `score` layer in the "
+                "`AutoModelForSequenceClassification` instance."
+            )
         self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)

         class ClassifierWithReshape(self.classifier.__class__):
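`getattr_iter` is imported from `vllm.config.utils` in the hunk above; as used here it returns the first attribute that exists among the candidate names, falling back to the default. A rough reimplementation to show the intent, since Hugging Face sequence-classification models name their head either `classifier` or `score` (the helper and dummy class below are illustrative, not the vllm.config.utils source):

    def getattr_iter(obj, names, default=None):
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        return default

    class DummySeqClsModel:
        score = "linear head"  # some HF models use `score` instead of `classifier`

    print(getattr_iter(DummySeqClsModel(), ["classifier", "score"], None))  # linear head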
@@ -46,6 +46,14 @@ class WeightsMapper:
     orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
     orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

+    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
+        """Combine two `WeightsMapper`s by merging their mappings."""
+        return WeightsMapper(
+            orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
+            orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
+            orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
+        )
+
     def _map_name(self, key: str) -> str | None:
         for substr, new_key in self.orig_to_new_substr.items():
             if substr in key:
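The merge inside `__or__` uses plain dict unpacking, so both mappers' entries are kept and the right-hand mapper wins on conflicting keys. A quick standalone illustration with made-up prefixes:

    a = {"": "model.", "model.score.": "classifier."}
    b = {"model.score.": "score.", "model.lm_head.": "lm_head."}
    print({**a, **b})
    # {'': 'model.', 'model.score.': 'score.', 'model.lm_head.': 'lm_head.'}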