diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py
new file mode 100644
index 0000000000000..5f20452aff3d8
--- /dev/null
+++ b/tests/models/multimodal/test_mapping.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import pytest
+import torch
+import transformers
+from transformers import AutoConfig, PreTrainedModel
+
+from vllm.config import ModelConfig
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.transformers_utils.config import try_get_safetensors_metadata
+
+from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
+
+
+def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
+    """Create weights from safetensors checkpoint metadata"""
+    metadata = try_get_safetensors_metadata(repo)
+    weight_names = list(metadata.weight_map.keys())
+    with torch.device('meta'):
+        return ((name, torch.empty(0)) for name in weight_names)
+
+
+def create_model_dummy_weights(
+    repo: str,
+    model_arch: str,
+) -> Iterable[tuple[str, torch.Tensor]]:
+    """
+    Create weights from a dummy meta deserialized hf model with name conversion
+    """
+    model_cls: PreTrainedModel = getattr(transformers, model_arch)
+    config = AutoConfig.from_pretrained(repo)
+    with torch.device("meta"):
+        model: PreTrainedModel = model_cls._from_config(config)
+    return model.named_parameters()
+
+
+def model_architectures_for_test() -> list[str]:
+    arch_to_test = list[str]()
+    for model_arch, info in _MULTIMODAL_EXAMPLE_MODELS.items():
+        if not info.trust_remote_code and hasattr(transformers, model_arch):
+            model_cls: PreTrainedModel = getattr(transformers, model_arch)
+            if getattr(model_cls, "_checkpoint_conversion_mapping", None):
+                arch_to_test.append(model_arch)
+    return arch_to_test
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model_arch", model_architectures_for_test())
+def test_hf_model_weights_mapper(model_arch: str):
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
+
+    model_id = model_info.default
+
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_info.tokenizer or model_id,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        seed=0,
+        dtype="auto",
+        revision=None,
+        hf_overrides=model_info.hf_overrides,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    original_weights = create_repo_dummy_weights(model_id)
+    hf_converted_weights = create_model_dummy_weights(model_id, model_arch)
+    mapper: WeightsMapper = model_cls.hf_to_vllm_mapper
+
+    mapped_original_weights = mapper.apply(original_weights)
+    mapped_hf_converted_weights = mapper.apply(hf_converted_weights)
+
+    ref_weight_names = set(map(lambda x: x[0], mapped_original_weights))
+    weight_names = set(map(lambda x: x[0], mapped_hf_converted_weights))
+
+    weights_missing = ref_weight_names - weight_names
+    weights_unmapped = weight_names - ref_weight_names
+    assert (not weights_missing and not weights_unmapped), (
+        f"Following weights are not mapped correctly: {weights_unmapped}, "
+        f"Missing expected weights: {weights_missing}.")
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index b69c7b6a9376d..4fe6a7b9e9383 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -486,6 +486,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            # mapping for original checkpoint
             "language_model.model": "language_model",
             "language_model.lm_head": "lm_head",
         },
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 6a95ac089ff4a..7c02d245db8b7 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -32,8 +32,9 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 
 class AyaVisionImagePixelInputs(TypedDict):
@@ -292,6 +293,15 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: AyaVisionConfig = vllm_config.model_config.hf_config
@@ -323,7 +333,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor,
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 462f85c3dd623..9692899f7b993 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -42,7 +42,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
 # Cannot find the following 2 numbers from hf config.
@@ -245,6 +245,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_embed_tokens.": "vision_embed_tokens.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index b633c0003c637..415a8dbdcf87f 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -35,8 +35,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -471,6 +472,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -697,7 +707,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 7dea260a58e0d..f70ad37a3d3ac 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -40,8 +40,9 @@ from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -499,6 +500,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
 
@@ -754,7 +764,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 class MantisProcessingInfo(LlavaProcessingInfo):
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 60ede454ff272..bc792be19dbf6 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -26,8 +26,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                     LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
+                    flatten_bn, init_vllm_registered_model, maybe_prefix)
 
 
 class LlavaNextImagePixelInputs(TypedDict):
@@ -205,6 +205,16 @@ class LlavaNextMultiModalProcessor(
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -583,4 +593,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 78084465e7a27..c13e8e9b24140 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -29,8 +29,9 @@ from vllm.utils import is_list_of
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -270,6 +271,16 @@ class LlavaNextMultiModalProjector(nn.Module):
 class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -468,4 +479,4 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             # This model doesn't support images for now
             ignore_unexpected_prefixes=["image_newline"],
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 265f63d7bd295..373b0a2a7d5e6 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -30,8 +30,9 @@ from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
 from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                          LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -428,6 +429,16 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -954,4 +965,4 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 59deacffd2851..ebc176e2c7242 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -36,8 +36,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -389,6 +390,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
 
@@ -592,7 +602,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index e9f91feb3359d..1b7e93fafad93 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -67,7 +67,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = init_logger(__name__)
 
@@ -790,6 +790,36 @@ class MllamaVisionModel(nn.Module):
                                  dim=-1)
         return hidden_state
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding._linear.weight' in name:
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
+
 
 class MllamaTextRMSNorm(nn.Module):
 
@@ -1132,6 +1162,7 @@ class MllamaForCausalLM(nn.Module):
 
         config = vllm_config.model_config.hf_config.text_config
         quant_config = vllm_config.quant_config
+        self.quant_config = quant_config
 
         self.vocab_size = config.vocab_size
         self.model = MllamaTextModel(vllm_config=vllm_config,
@@ -1167,6 +1198,58 @@ class MllamaForCausalLM(nn.Module):
         )
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding.weight' in name:
+                name = name.replace('patch_embedding.weight',
+                                    'patch_embedding._linear.weight')
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                updated_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue
+
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
+
 
 @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                         info=MllamaProcessingInfo,
@@ -1178,6 +1261,19 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.vision_model.": "vision_model.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        },
+        orig_to_new_suffix={
+            "patch_embedding.weight": "patch_embedding._linear.weight",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: MllamaConfig = vllm_config.model_config.hf_config
@@ -1479,55 +1575,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        updated_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if 'patch_embedding.weight' in name:
-                name = name.replace('patch_embedding.weight',
-                                    'patch_embedding._linear.weight')
-                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                updated_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                updated_params.add(name)
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                orig_name = name
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    logger.debug("Missing name %s, orig name %s", name,
-                                 orig_name)
-                    continue
-
-                param = params_dict.pop(name)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                updated_params.add(name)
-        return updated_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index cc2cebe4a4a37..103a267c41f59 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -24,8 +24,9 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 logger = init_logger(__name__)
@@ -227,6 +228,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -395,4 +405,4 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
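For reference, here is a minimal sketch (not part of the patch) of what the new `hf_to_vllm_mapper` attributes do at load time: `AutoWeightsLoader.load_weights(..., mapper=...)` renames checkpoint keys saved by transformers v4.52+ into vLLM's module layout before loading. Only the `WeightsMapper` / `orig_to_new_prefix` / `apply` usage visible in this diff is assumed; the weight names below are hypothetical examples.

```python
# Illustrative sketch, not part of the patch: remap transformers>=4.52
# checkpoint keys into vLLM's module layout using the same kind of
# WeightsMapper configuration added above. The weight names are made up.
import torch

from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.language_model.": "language_model.model.",
        "model.vision_tower.": "vision_tower.",
        "model.multi_modal_projector.": "multi_modal_projector.",
        "lm_head.": "language_model.lm_head.",
    })

# Dummy (name, tensor) pairs in the post-v4.52 checkpoint naming scheme.
weights = [
    ("model.language_model.layers.0.self_attn.q_proj.weight", torch.empty(0)),
    ("model.vision_tower.encoder.layers.0.mlp.fc1.weight", torch.empty(0)),
    ("lm_head.weight", torch.empty(0)),
]

for name, _ in mapper.apply(weights):
    print(name)
# language_model.model.layers.0.self_attn.q_proj.weight
# vision_tower.encoder.layers.0.mlp.fc1.weight
# language_model.lm_head.weight
```

The new test in `tests/models/multimodal/test_mapping.py` exercises exactly this path: it applies each model's `hf_to_vllm_mapper` to both the original checkpoint names and the names produced by the HF `_checkpoint_conversion_mapping`, and asserts that the two mapped name sets coincide.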