diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6e955e1c51213..a43803ed43333 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -22,12 +22,11 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only -from .utils import WeightsMapper, maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): @@ -44,9 +43,11 @@ class BertEmbedding(nn.Module): config.type_vocab_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.position_ids = nn.Parameter( - torch.empty((1, config.max_position_embeddings)), ) + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).unsqueeze(0), + ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": raise ValueError("Only 'absolute' position_embedding_type" + @@ -358,45 +359,45 @@ class BertModel(nn.Module, SupportsQuant): ("qkv_proj", "value", "v"), ] + loaded_stacked_params = [] + other_weights = [] params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.pooler is None and "pooler" in name: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + loaded_stacked_params.append(name) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) + if name in params_dict: + other_weights.append((name, loaded_weight)) + + loader = AutoWeightsLoader( + self, + skip_prefixes=(["pooler."] if self.pooler is None else []), + ) + loaded_params = loader.load_weights(other_weights) + loaded_params.update(loaded_stacked_params) return loaded_params class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): """A model that uses Bert to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -425,10 +426,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) + weights_list = list(weights) + + has_model_prefix = any( + name.startswith("model.") for name, _ in weights_list) + if not has_model_prefix: + mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) + + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) + return loader.load_weights(weights_list, mapper=mapper) def _build_model(self, vllm_config: VllmConfig, @@ -470,26 +476,9 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, self.classifier, self.bert.pooler) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - - self_weights = [] - - def weight_filter(): - for name, weight in weights: - if name.startswith("bert."): - yield (name[len("bert."):], weight) - else: - self_weights.append((name, weight)) - - self.bert.load_weights(weight_filter()) - - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in self_weights: - if name.startswith("classifier"): - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params def pooler( self, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 048fa827fb2b9..1d3a23a5e5445 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools from collections.abc import Iterable from typing import Optional, Union @@ -13,9 +12,9 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput @@ -39,8 +38,10 @@ class RobertaEmbedding(nn.Module): config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.position_ids = nn.Parameter( - torch.empty((1, config.max_position_embeddings)), ) + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).unsqueeze(0), + ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": @@ -136,16 +137,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel): embedding_class=RobertaEmbedding) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - # Separate weights in "roberta"-prefixed and all else (not in memory). - # For use with models like FacebookAI/roberta-base. - bert_weights, task_weights = roberta_task_weights_filter(weights) - loaded = self.model.load_weights(bert_weights) - if not len(loaded): - # Fix for models like `sentence-transformers/stsb-roberta-base-v2` - # which use the same architecture, but have no "roberta" prefix. - loaded = self.model.load_weights(task_weights) - assert len(loaded), "Unable to load RobertaEmbeddingModel" + weights_list = list(weights) + has_roberta_prefix = any( + name.startswith("roberta.") for name, _ in weights_list) + if has_roberta_prefix: + # For models with the `roberta.` prefix e.g. + # `FacebookAI/roberta-base` + mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."}) + else: + # For models without the `roberta.` prefix e.g. + # `sentence-transformers/stsb-roberta-base-v2` + mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) + + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) + return loader.load_weights(weights_list, mapper=mapper) class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, @@ -187,19 +192,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, self.classifier) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - bert_weights, task_weights = roberta_task_weights_filter(weights) - bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) - - self.roberta.load_weights(bert_weights) - - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in task_weights: - if name.startswith("classifier"): - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) def pooler( self, @@ -245,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids, past_key_values_length) * mask return incremental_indices.long() + padding_idx - - -def roberta_task_weights_filter( - all_weights: Iterable[tuple[str, torch.Tensor]] -) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, - torch.Tensor]]]: - """ - Separate task-specific weights that are applied on top - of the encoder-decoder bert base. - To do so, return two generators over the original iterator. - Also, remove the "roberta." prefix to make it loadable - from vanilla BertModel. - """ - # Copy of a lazy iterator without in-memory overhead so both - # iterators can be iterated upon independently. - all_weights1, all_weights2 = itertools.tee(all_weights) - - def encoder_decoder_weights(): - for name, weight in all_weights1: - if name.startswith("roberta."): - yield (name[len("roberta."):], weight) - - return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 - if not n.startswith("roberta."))