[Bugfix] Update multimodal models mapping to fit new checkpoint after Transformers v4.52 (#19151)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-06-17 23:58:38 +08:00 committed by GitHub
parent 5a1c2e15d8
commit ca94d7fa00
12 changed files with 304 additions and 75 deletions
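
For context on what the new mappings translate: multimodal checkpoints exported with Transformers v4.52+ nest the language model, vision tower, and projector under a top-level `model.` prefix, whereas vLLM keeps them as separate top-level submodules. The sketch below is a minimal, self-contained illustration of that prefix rewrite; it is not vLLM's actual `WeightsMapper` implementation, and the example key is hypothetical.

# Minimal sketch of the prefix rewrite that the per-model hf_to_vllm_mapper
# tables in this commit perform (illustrative only; not vLLM's WeightsMapper).
HF_TO_VLLM_PREFIXES = {
    # new-style names (Transformers >= 4.52)  ->  vLLM module names
    "model.language_model.": "language_model.model.",
    "model.vision_tower.": "vision_tower.",
    "model.multi_modal_projector.": "multi_modal_projector.",
    "lm_head.": "language_model.lm_head.",
}


def remap_name(name: str) -> str:
    """Rewrite the first matching checkpoint-key prefix, if any."""
    for old, new in HF_TO_VLLM_PREFIXES.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


# Hypothetical key, for illustration:
assert (remap_name("model.vision_tower.encoder.layers.0.q_proj.weight")
        == "vision_tower.encoder.layers.0.q_proj.weight")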

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import pytest
import torch
import transformers
from transformers import AutoConfig, PreTrainedModel

from vllm.config import ModelConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata

from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS


def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
    """Create weights from safetensors checkpoint metadata"""
    metadata = try_get_safetensors_metadata(repo)
    weight_names = list(metadata.weight_map.keys())
    with torch.device('meta'):
        return ((name, torch.empty(0)) for name in weight_names)


def create_model_dummy_weights(
    repo: str,
    model_arch: str,
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Create weights from a dummy meta-deserialized HF model with name conversion
    """
    model_cls: PreTrainedModel = getattr(transformers, model_arch)
    config = AutoConfig.from_pretrained(repo)
    with torch.device("meta"):
        model: PreTrainedModel = model_cls._from_config(config)
    return model.named_parameters()


def model_architectures_for_test() -> list[str]:
    arch_to_test = list[str]()
    for model_arch, info in _MULTIMODAL_EXAMPLE_MODELS.items():
        if not info.trust_remote_code and hasattr(transformers, model_arch):
            model_cls: PreTrainedModel = getattr(transformers, model_arch)
            if getattr(model_cls, "_checkpoint_conversion_mapping", None):
                arch_to_test.append(model_arch)
    return arch_to_test


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", model_architectures_for_test())
def test_hf_model_weights_mapper(model_arch: str):
    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    model_id = model_info.default
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="auto",
        revision=None,
        hf_overrides=model_info.hf_overrides,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

    original_weights = create_repo_dummy_weights(model_id)
    hf_converted_weights = create_model_dummy_weights(model_id, model_arch)

    mapper: WeightsMapper = model_cls.hf_to_vllm_mapper
    mapped_original_weights = mapper.apply(original_weights)
    mapped_hf_converted_weights = mapper.apply(hf_converted_weights)

    ref_weight_names = set(map(lambda x: x[0], mapped_original_weights))
    weight_names = set(map(lambda x: x[0], mapped_hf_converted_weights))

    weights_missing = ref_weight_names - weight_names
    weights_unmapped = weight_names - ref_weight_names
    assert (not weights_missing and not weights_unmapped), (
        f"Following weights are not mapped correctly: {weights_unmapped}, "
        f"Missing expected weights: {weights_missing}.")

View File

@@ -486,6 +486,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            # mapping for original checkpoint
             "language_model.model": "language_model",
             "language_model.lm_head": "lm_head",
         },

View File

@@ -32,8 +32,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)


 class AyaVisionImagePixelInputs(TypedDict):
@@ -292,6 +293,15 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: AyaVisionConfig = vllm_config.model_config.hf_config
@@ -323,7 +333,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor,

View File

@@ -42,7 +42,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)

 # Cannot find the following 2 numbers from hf config.
@@ -245,6 +245,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_embed_tokens.": "vision_embed_tokens.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config

View File

@@ -35,8 +35,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -471,6 +472,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -697,7 +707,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

     def get_mm_mapping(self) -> MultiModelKeys:
         """

View File

@@ -40,8 +40,9 @@ from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
@@ -499,6 +500,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         "gate_up_proj": ["gate_proj", "up_proj"]
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
@@ -754,7 +764,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 class MantisProcessingInfo(LlavaProcessingInfo):

View File

@@ -26,8 +26,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                     LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
+                    flatten_bn, init_vllm_registered_model, maybe_prefix)


 class LlavaNextImagePixelInputs(TypedDict):
@@ -205,6 +205,16 @@ class LlavaNextMultiModalProcessor(
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -583,4 +593,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -29,8 +29,9 @@ from vllm.utils import is_list_of
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
@@ -270,6 +271,16 @@ class LlavaNextMultiModalProjector(nn.Module):
 class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -468,4 +479,4 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             # This model doesn't support images for now
             ignore_unexpected_prefixes=["image_newline"],
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -30,8 +30,9 @@ from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
 from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                          LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)

 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -428,6 +429,16 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -954,4 +965,4 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -36,8 +36,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
@@ -389,6 +390,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
@@ -592,7 +602,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

     def get_mm_mapping(self) -> MultiModelKeys:
         """

View File

@@ -67,7 +67,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

 logger = init_logger(__name__)
@@ -790,6 +790,36 @@ class MllamaVisionModel(nn.Module):
             dim=-1)
         return hidden_state

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding._linear.weight' in name:
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
+

 class MllamaTextRMSNorm(nn.Module):
@@ -1132,6 +1162,7 @@ class MllamaForCausalLM(nn.Module):
         config = vllm_config.model_config.hf_config.text_config
         quant_config = vllm_config.quant_config
+        self.quant_config = quant_config

         self.vocab_size = config.vocab_size
         self.model = MllamaTextModel(vllm_config=vllm_config,
@@ -1167,6 +1198,58 @@ class MllamaForCausalLM(nn.Module):
         )
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding.weight' in name:
+                name = name.replace('patch_embedding.weight',
+                                    'patch_embedding._linear.weight')
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                updated_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
+

 @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                         info=MllamaProcessingInfo,
@@ -1178,6 +1261,19 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.vision_model.": "vision_model.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        },
+        orig_to_new_suffix={
+            "patch_embedding.weight": "patch_embedding._linear.weight",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: MllamaConfig = vllm_config.model_config.hf_config
@@ -1479,55 +1575,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        updated_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if 'patch_embedding.weight' in name:
-                name = name.replace('patch_embedding.weight',
-                                    'patch_embedding._linear.weight')
-                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                updated_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                updated_params.add(name)
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                orig_name = name
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    logger.debug("Missing name %s, orig name %s", name,
-                                 orig_name)
-                    continue
-                param = params_dict.pop(name)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                updated_params.add(name)
-        return updated_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

     def get_mm_mapping(self) -> MultiModelKeys:
         """

View File

@@ -24,8 +24,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info

 logger = init_logger(__name__)
@@ -227,6 +228,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -395,4 +405,4 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)