diff --git a/vllm/config/model.py b/vllm/config/model.py
index 27bcbf90c2bc..f81d324d8f80 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -793,25 +793,20 @@ class ModelConfig:
         cls += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
-        convert = None
+        task = None
         if defaults := try_match_architecture_defaults(self.architectures[0]):
-            _, (runner, convert) = defaults
-        # Overwrite with user-specified values
+            _, (runner, task) = defaults
+        # User-specified values take precedence
         if self.runner != "auto":
             runner = self.runner
-        if self.convert not in {"auto", "none"}:
-            convert = self.convert
-        # Fall back to default values if still not set
-        if runner is None:
-            runner = "generate"
-        if convert in {None, "none"}:
-            convert = "embed"
-        # Resolve Transformers backend task
-        if runner == "pooling":
-            if convert == "embed":
-                return cls + "EmbeddingModel"
-            if convert == "classify":
-                return cls + "ForSequenceClassification"
+        # Only consider Transformers backend pooling classes if we're wrapping an
+        # architecture that defaults to pooling. Otherwise, we return the LM class
+        # and use adapters.
+        if runner == "pooling" and task in {"embed", "classify"}:
+            if task == "embed":
+                cls += "EmbeddingModel"
+            elif task == "classify":
+                cls += "ForSequenceClassification"
         else:
             cls += "ForCausalLM"
         return cls
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 5d51cd375741..7990024c55d0 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -283,7 +283,6 @@ def as_seq_cls_model(cls: _T) -> _T:
         Pooler,
     )
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
-    from vllm.sequence import IntermediateTensors
 
     from .utils import maybe_prefix
 
@@ -291,13 +290,13 @@ def as_seq_cls_model(cls: _T) -> _T:
         _create_pooling_model_cls(cls), SupportsCrossEncoding
     ):
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
-            config = vllm_config.model_config.hf_config
+            text_config = vllm_config.model_config.hf_config.get_text_config()
             model_config = vllm_config.model_config
             quant_config = vllm_config.quant_config
 
             self.score = ReplicatedLinear(
                 model_config.hidden_size,
-                config.num_labels,
+                text_config.num_labels,
                 bias=False,
                 params_dtype=vllm_config.model_config.head_dtype,
                 quant_config=quant_config,
@@ -322,20 +321,10 @@ def as_seq_cls_model(cls: _T) -> _T:
                 }
             )
 
-        def forward(
-            self,
-            input_ids: torch.Tensor,
-            positions: torch.Tensor,
-            intermediate_tensors: IntermediateTensors | None = None,
-            inputs_embeds: torch.Tensor | None = None,
-        ) -> torch.Tensor:
-            return super().forward(
-                input_ids, positions, intermediate_tensors, inputs_embeds
-            )
-
         def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-            tokens = getattr(self.config, "classifier_from_token", None)
-            method = getattr(self.config, "method", None)
+            text_config = self.config.get_text_config()
+            tokens = getattr(text_config, "classifier_from_token", None)
+            method = getattr(text_config, "method", None)
 
             if tokens is None and method is None:
                 return super().load_weights(weights)
@@ -392,9 +381,9 @@ def as_reward_model(cls: _T) -> _T:
 class SequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
-        method = getattr(config, "method", None)
-        tokens = getattr(config, "classifier_from_token", None)
+        text_config = vllm_config.model_config.hf_config.get_text_config()
+        method = getattr(text_config, "method", None)
+        tokens = getattr(text_config, "classifier_from_token", None)
 
         if method is None:
             return
@@ -404,13 +393,13 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
 
         if method == "from_2_way_softmax":
             assert len(tokens) == 2
-            config.num_labels = 1
+            text_config.num_labels = 1
         else:
-            config.num_labels = len(tokens)
+            text_config.num_labels = len(tokens)
 
         # `llm as reranker` defaults to not using pad_token
-        use_pad_token = getattr(config, "use_pad_token", False)
-        config.use_pad_token = use_pad_token
+        use_pad_token = getattr(text_config, "use_pad_token", False)
+        text_config.use_pad_token = use_pad_token
 
 
 def load_weights_using_from_2_way_softmax(
@@ -419,24 +408,31 @@ def load_weights_using_from_2_way_softmax(
     # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-    from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
+    quant_config = model.vllm_config.quant_config
+    text_config = model.config.get_text_config()
 
-    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = getattr(text_config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
-    if model.config.tie_word_embeddings:
-        model.lm_head = model.model.embed_tokens
-    else:
-        quant_config = model.vllm_config.quant_config
-        model.lm_head = ParallelLMHead(
-            model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
+    model.lm_head = ParallelLMHead(
+        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
+    )
+    if text_config.tie_word_embeddings:
+        # embed_tokens is the assumed name for input embeddings. If the model does not
+        # have this attribute, we fall back to get_input_embeddings(), which is used by
+        # the Transformers backend.
+        embed_tokens = (
+            model.model.embed_tokens
+            if hasattr(model.model, "embed_tokens")
+            else model.model.get_input_embeddings()
         )
+        model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    loader = AutoWeightsLoader(model)
-    loaded_weights = loader.load_weights(weights)
+    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -466,23 +462,31 @@ def load_weights_using_from_2_way_softmax(
 def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-    from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
-    tokens = getattr(model.config, "classifier_from_token", [])
+    quant_config = model.vllm_config.quant_config
+    text_config = model.config.get_text_config()
+
+    tokens = getattr(text_config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0
 
-    if model.config.tie_word_embeddings:
-        model.lm_head = model.model.embed_tokens
-    else:
-        quant_config = model.vllm_config.quant_config
-        model.lm_head = ParallelLMHead(
-            model.config.vocab_size, model.config.hidden_size, quant_config=quant_config
+    model.lm_head = ParallelLMHead(
+        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
+    )
+    if text_config.tie_word_embeddings:
+        # embed_tokens is the assumed name for input embeddings. If the model does not
+        # have this attribute, we fall back to get_input_embeddings(), which is used by
+        # the Transformers backend.
+        embed_tokens = (
+            model.model.embed_tokens
+            if hasattr(model.model, "embed_tokens")
+            else model.model.get_input_embeddings()
         )
+        model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    loader = AutoWeightsLoader(model)
-    loaded_weights = loader.load_weights(weights)
+    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -523,7 +527,7 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
     # - GemmaForCausalLM
     #   - bge-reranker-v2-gemma
 
-    config = model.vllm_config.model_config.hf_config
-    method = getattr(config, "method", None)
+    text_config = model.vllm_config.model_config.hf_config.get_text_config()
+    method = getattr(text_config, "method", None)
     assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
     return SEQ_CLS_LOAD_METHODS[method](model, weights)
diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index d940bb9739ce..41d170c9e139 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -49,6 +49,7 @@ from vllm.model_executor.models.transformers.utils import (
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -92,6 +93,27 @@ ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # Add `model.` prefix for base model checkpoints,
+            # handling the case where it is already present
+            "": "model.",
+            "model.model.": "model.",
+            # Heads will be adjacent to `model` (pooling included because of adapters)
+            "model.lm_head.": "lm_head.",
+            "model.score.": "classifier.",
+            "model.classifier.": "classifier.",
+        }
+    )
+
+    def __init_subclass__(cls, *args, **kwargs):
+        """Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
+        super().__init_subclass__(*args, **kwargs)
+        hf_to_vllm_mapper = WeightsMapper()
+        for base in cls.__mro__:
+            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
+                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
+        cls.hf_to_vllm_mapper = hf_to_vllm_mapper
 
     def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
         super().__init__()
diff --git a/vllm/model_executor/models/transformers/legacy.py b/vllm/model_executor/models/transformers/legacy.py
index 5d4dcf055607..a453870a2687 100644
--- a/vllm/model_executor/models/transformers/legacy.py
+++ b/vllm/model_executor/models/transformers/legacy.py
@@ -34,13 +34,6 @@ class LegacyMixin:
             # Handle BERT-like models
             "roberta": "model",
             "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Classifier/scoring heads will be adjacent to `model`
-            "model.score": "classifier",
-            "model.classifier": "classifier",
         },
         orig_to_new_suffix={
             # Replace legacy suffixes used for norms
diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py
index 32aec49066fa..8117bbac013e 100644
--- a/vllm/model_executor/models/transformers/pooling.py
+++ b/vllm/model_executor/models/transformers/pooling.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 import torch
 from transformers import AutoModelForSequenceClassification
 
+from vllm.config.utils import getattr_iter
 from vllm.model_executor.layers.pooler import (
     ClassifierPooler,
     CLSPool,
@@ -82,14 +83,14 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
             if hasattr(module, "pooler") and module.pooler is None:
                 self.model.pooler = None
                 break
-        if self.model.pooler is not None:
-            raise ValueError(
-                "Sequence classification models with pooling layers are not "
-                "supported yet in the Transformers backend."
-            )
 
         # Unlike `lm_head`, `classifier` is not always `nn.Linear`.
-        self.classifier = seq_cls_model.classifier
+        self.classifier = getattr_iter(seq_cls_model, ["classifier", "score"], None)
+        if self.classifier is None:
+            raise ValueError(
+                "Could not find `classifier` or `score` layer in the "
+                "`AutoModelForSequenceClassification` instance."
+            )
         self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
 
         class ClassifierWithReshape(self.classifier.__class__):
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index e4f540dea8f7..e86fc23c7d36 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -46,6 +46,14 @@ class WeightsMapper:
     orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
     orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
 
+    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
+        """Combine two `WeightsMapper`s by merging their mappings."""
+        return WeightsMapper(
+            orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
+            orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
+            orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
+        )
+
     def _map_name(self, key: str) -> str | None:
         for substr, new_key in self.orig_to_new_substr.items():
             if substr in key: