[Bugfix] Update multimodel models mapping to fit new checkpoint after Transformers v4.52 (#19151)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-06-17 23:58:38 +08:00 committed by GitHub
parent 5a1c2e15d8
commit ca94d7fa00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 304 additions and 75 deletions

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import pytest
import torch
import transformers
from transformers import AutoConfig, PreTrainedModel
from vllm.config import ModelConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
"""Create weights from safetensors checkpoint metadata"""
metadata = try_get_safetensors_metadata(repo)
weight_names = list(metadata.weight_map.keys())
with torch.device('meta'):
return ((name, torch.empty(0)) for name in weight_names)
def create_model_dummy_weights(
repo: str,
model_arch: str,
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create weights from a dummy meta deserialized hf model with name conversion
"""
model_cls: PreTrainedModel = getattr(transformers, model_arch)
config = AutoConfig.from_pretrained(repo)
with torch.device("meta"):
model: PreTrainedModel = model_cls._from_config(config)
return model.named_parameters()
def model_architectures_for_test() -> list[str]:
arch_to_test = list[str]()
for model_arch, info in _MULTIMODAL_EXAMPLE_MODELS.items():
if not info.trust_remote_code and hasattr(transformers, model_arch):
model_cls: PreTrainedModel = getattr(transformers, model_arch)
if getattr(model_cls, "_checkpoint_conversion_mapping", None):
arch_to_test.append(model_arch)
return arch_to_test
@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", model_architectures_for_test())
def test_hf_model_weights_mapper(model_arch: str):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
model_id = model_info.default
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="auto",
revision=None,
hf_overrides=model_info.hf_overrides,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
original_weights = create_repo_dummy_weights(model_id)
hf_converted_weights = create_model_dummy_weights(model_id, model_arch)
mapper: WeightsMapper = model_cls.hf_to_vllm_mapper
mapped_original_weights = mapper.apply(original_weights)
mapped_hf_converted_weights = mapper.apply(hf_converted_weights)
ref_weight_names = set(map(lambda x: x[0], mapped_original_weights))
weight_names = set(map(lambda x: x[0], mapped_hf_converted_weights))
weights_missing = ref_weight_names - weight_names
weights_unmapped = weight_names - ref_weight_names
assert (not weights_missing and not weights_unmapped), (
f"Following weights are not mapped correctly: {weights_unmapped}, "
f"Missing expected weights: {weights_missing}.")

View File

@ -486,6 +486,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
# mapping for original checkpoint
"language_model.model": "language_model",
"language_model.lm_head": "lm_head",
},

View File

@ -32,8 +32,9 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
class AyaVisionImagePixelInputs(TypedDict):
@ -292,6 +293,15 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: AyaVisionConfig = vllm_config.model_config.hf_config
@ -323,7 +333,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,

View File

@ -42,7 +42,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config.
@ -245,6 +245,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.": "vision_embed_tokens.",
"model.language_model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

View File

@ -35,8 +35,9 @@ from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -471,6 +472,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -697,7 +707,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""

View File

@ -40,8 +40,9 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@ -499,6 +500,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
"gate_up_proj": ["gate_proj", "up_proj"]
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@ -754,7 +764,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class MantisProcessingInfo(LlavaProcessingInfo):

View File

@ -26,8 +26,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model, maybe_prefix)
from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict):
@ -205,6 +205,16 @@ class LlavaNextMultiModalProcessor(
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
@ -583,4 +593,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@ -29,8 +29,9 @@ from vllm.utils import is_list_of
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@ -270,6 +271,16 @@ class LlavaNextMultiModalProjector(nn.Module):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
@ -468,4 +479,4 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
# This model doesn't support images for now
ignore_unexpected_prefixes=["image_newline"],
)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@ -30,8 +30,9 @@ from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
LlavaNextProcessingInfo)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@ -428,6 +429,16 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
@ -954,4 +965,4 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@ -36,8 +36,9 @@ from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@ -389,6 +390,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
"gate_up_proj": ["gate_proj", "up_proj"]
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@ -592,7 +602,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""

View File

@ -67,7 +67,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal, SupportsV0Only
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
logger = init_logger(__name__)
@ -790,6 +790,36 @@ class MllamaVisionModel(nn.Module):
dim=-1)
return hidden_state
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
for name, loaded_weight in weights:
if 'patch_embedding._linear.weight' in name:
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
class MllamaTextRMSNorm(nn.Module):
@ -1132,6 +1162,7 @@ class MllamaForCausalLM(nn.Module):
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(vllm_config=vllm_config,
@ -1167,6 +1198,58 @@ class MllamaForCausalLM(nn.Module):
)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
for name, loaded_weight in weights:
if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
orig_name = name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
logger.debug("Missing name %s, orig name %s", name,
orig_name)
continue
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
info=MllamaProcessingInfo,
@ -1178,6 +1261,19 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
"gate_up_proj": ["gate_proj", "up_proj"]
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.vision_model.": "vision_model.",
"model.multi_modal_projector.": "multi_modal_projector.",
"model.language_model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
},
orig_to_new_suffix={
"patch_embedding.weight": "patch_embedding._linear.weight",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: MllamaConfig = vllm_config.model_config.hf_config
@ -1479,55 +1575,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
for name, loaded_weight in weights:
if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
orig_name = name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
logger.debug("Missing name %s, orig name %s", name,
orig_name)
continue
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""

View File

@ -24,8 +24,9 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
@ -227,6 +228,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -395,4 +405,4 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)