Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 11:16:45 +08:00)
[Bugfix] Update multimodal model mappings to fit new checkpoints saved after Transformers v4.52 (#19151)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent 5a1c2e15d8
commit ca94d7fa00

tests/models/multimodal/test_mapping.py (new file, 86 additions)
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import pytest
+import torch
+import transformers
+from transformers import AutoConfig, PreTrainedModel
+
+from vllm.config import ModelConfig
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.transformers_utils.config import try_get_safetensors_metadata
+
+from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
+
+
+def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
+    """Create weights from safetensors checkpoint metadata"""
+    metadata = try_get_safetensors_metadata(repo)
+    weight_names = list(metadata.weight_map.keys())
+    with torch.device('meta'):
+        return ((name, torch.empty(0)) for name in weight_names)
+
+
+def create_model_dummy_weights(
+    repo: str,
+    model_arch: str,
+) -> Iterable[tuple[str, torch.Tensor]]:
+    """
+    Create weights from a dummy meta deserialized hf model with name conversion
+    """
+    model_cls: PreTrainedModel = getattr(transformers, model_arch)
+    config = AutoConfig.from_pretrained(repo)
+    with torch.device("meta"):
+        model: PreTrainedModel = model_cls._from_config(config)
+    return model.named_parameters()
+
+
+def model_architectures_for_test() -> list[str]:
+    arch_to_test = list[str]()
+    for model_arch, info in _MULTIMODAL_EXAMPLE_MODELS.items():
+        if not info.trust_remote_code and hasattr(transformers, model_arch):
+            model_cls: PreTrainedModel = getattr(transformers, model_arch)
+            if getattr(model_cls, "_checkpoint_conversion_mapping", None):
+                arch_to_test.append(model_arch)
+    return arch_to_test
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model_arch", model_architectures_for_test())
+def test_hf_model_weights_mapper(model_arch: str):
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
+
+    model_id = model_info.default
+
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_info.tokenizer or model_id,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        seed=0,
+        dtype="auto",
+        revision=None,
+        hf_overrides=model_info.hf_overrides,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    original_weights = create_repo_dummy_weights(model_id)
+    hf_converted_weights = create_model_dummy_weights(model_id, model_arch)
+    mapper: WeightsMapper = model_cls.hf_to_vllm_mapper
+
+    mapped_original_weights = mapper.apply(original_weights)
+    mapped_hf_converted_weights = mapper.apply(hf_converted_weights)
+
+    ref_weight_names = set(map(lambda x: x[0], mapped_original_weights))
+    weight_names = set(map(lambda x: x[0], mapped_hf_converted_weights))
+
+    weights_missing = ref_weight_names - weight_names
+    weights_unmapped = weight_names - ref_weight_names
+    assert (not weights_missing and not weights_unmapped), (
+        f"Following weights are not mapped correctly: {weights_unmapped}, "
+        f"Missing expected weights: {weights_missing}.")
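The assertion above boils down to renaming two streams of (name, tensor) pairs with the same prefix rules and checking that the resulting name sets coincide. For readers skimming the per-model changes that follow, here is a minimal standalone sketch of the renaming idea; it mimics, but is not, vLLM's WeightsMapper, and every layer name in it is invented for the example:

# Toy prefix remapper, for illustration only (not vLLM's WeightsMapper).
def remap(names):
    prefix_map = {
        # key prefixes written by checkpoints saved with Transformers >= 4.52
        "model.language_model.": "language_model.model.",
        "model.vision_tower.": "vision_tower.",
        "lm_head.": "language_model.lm_head.",
    }
    remapped = set()
    for name in names:
        for old, new in prefix_map.items():
            if name.startswith(old):
                name = new + name[len(old):]
        remapped.add(name)
    return remapped

# The same tensors as named by a post-4.52 save (new_style) and by the older
# convention that matches vLLM's module tree (old_style).
new_style = [
    "model.language_model.layers.0.mlp.up_proj.weight",
    "model.vision_tower.encoder.layers.0.mlp.fc1.weight",
    "lm_head.weight",
]
old_style = {
    "language_model.model.layers.0.mlp.up_proj.weight",
    "vision_tower.encoder.layers.0.mlp.fc1.weight",
    "language_model.lm_head.weight",
}
assert remap(new_style) == old_style

vLLM's real WeightsMapper, imported in the test above from vllm.model_executor.models.utils, additionally supports suffix rules (orig_to_new_suffix), which the Mllama change further down relies on.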
@@ -486,6 +486,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            # mapping for original checkpoint
             "language_model.model": "language_model",
             "language_model.lm_head": "lm_head",
         },
@@ -32,8 +32,9 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 
 class AyaVisionImagePixelInputs(TypedDict):
@@ -292,6 +293,15 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: AyaVisionConfig = vllm_config.model_config.hf_config
@@ -323,7 +333,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor,
@@ -42,7 +42,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
 # Cannot find the following 2 numbers from hf config.
@@ -245,6 +245,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_embed_tokens.": "vision_embed_tokens.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -35,8 +35,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -471,6 +472,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -697,7 +707,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
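The same three prefix entries, plus small per-model additions such as image_newline, repeat in the files below. When it is unclear which naming generation a given Hub checkpoint actually uses, the safetensors metadata that the new test already queries can answer that directly. A rough sketch, assuming try_get_safetensors_metadata resolves metadata for the repo; the repo id is only an example and may require authentication:

# Sketch: inspect checkpoint key prefixes with the same helper the test uses.
from vllm.transformers_utils.config import try_get_safetensors_metadata

metadata = try_get_safetensors_metadata("google/gemma-3-4b-it")  # example repo
keys = list(metadata.weight_map.keys())

# Post-4.52 saves nest everything under "model."; older saves do not.
new_style = [k for k in keys if k.startswith("model.language_model.")]
old_style = [k for k in keys if k.startswith("language_model.model.")]
print(f"post-4.52 style keys: {len(new_style)}, legacy keys: {len(old_style)}")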
@@ -40,8 +40,9 @@ from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -499,6 +500,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
 
@@ -754,7 +764,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 class MantisProcessingInfo(LlavaProcessingInfo):
@@ -26,8 +26,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                     LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
+                    flatten_bn, init_vllm_registered_model, maybe_prefix)
 
 
 class LlavaNextImagePixelInputs(TypedDict):
@@ -205,6 +205,16 @@ class LlavaNextMultiModalProcessor(
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -583,4 +593,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -29,8 +29,9 @@ from vllm.utils import is_list_of
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -270,6 +271,16 @@ class LlavaNextMultiModalProjector(nn.Module):
 class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -468,4 +479,4 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             # This model doesn't support images for now
             ignore_unexpected_prefixes=["image_newline"],
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -30,8 +30,9 @@ from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
 from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                          LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -428,6 +429,16 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -954,4 +965,4 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -36,8 +36,9 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 
@@ -389,6 +390,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
 
@@ -592,7 +602,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
@@ -67,7 +67,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = init_logger(__name__)
 
@@ -790,6 +790,36 @@ class MllamaVisionModel(nn.Module):
                                  dim=-1)
         return hidden_state
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding._linear.weight' in name:
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
 
 
 class MllamaTextRMSNorm(nn.Module):
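MllamaVisionModel now gets its own load_weights so that its per-head q/k/v projections can be filled into vLLM's fused qkv_proj parameter. The stacked_params_mapping tuples drive a purely name-based routing: a checkpoint key containing one of the per-head names is pointed at the fused parameter together with a shard id that tells that parameter's weight_loader which slice to fill. A standalone sketch of that routing step, with an invented layer path (not vLLM code):

# Toy version of the name routing done by the loop above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

def route(name):
    for fused_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # the shard id identifies which slice of the fused tensor to fill
            return name.replace(shard_name, fused_name), shard_id
    return name, None  # unstacked weights load as-is

print(route("encoder.layers.0.self_attn.k_proj.weight"))
# -> ('encoder.layers.0.self_attn.qkv_proj.weight', 'k')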
@@ -1132,6 +1162,7 @@ class MllamaForCausalLM(nn.Module):
 
         config = vllm_config.model_config.hf_config.text_config
         quant_config = vllm_config.quant_config
+        self.quant_config = quant_config
 
         self.vocab_size = config.vocab_size
         self.model = MllamaTextModel(vllm_config=vllm_config,
@@ -1167,6 +1198,58 @@ class MllamaForCausalLM(nn.Module):
         )
         return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding.weight' in name:
+                name = name.replace('patch_embedding.weight',
+                                    'patch_embedding._linear.weight')
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                updated_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue
+
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
 
 
 @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                         info=MllamaProcessingInfo,
@@ -1178,6 +1261,19 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.vision_model.": "vision_model.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.language_model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        },
+        orig_to_new_suffix={
+            "patch_embedding.weight": "patch_embedding._linear.weight",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: MllamaConfig = vllm_config.model_config.hf_config
@@ -1479,55 +1575,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        updated_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if 'patch_embedding.weight' in name:
-                name = name.replace('patch_embedding.weight',
-                                    'patch_embedding._linear.weight')
-                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                updated_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                updated_params.add(name)
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                orig_name = name
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    logger.debug("Missing name %s, orig name %s", name,
-                                 orig_name)
-                    continue
-
-                param = params_dict.pop(name)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                updated_params.add(name)
-        return updated_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
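The net effect of the Mllama hunks above: roughly fifty lines of hand-written renaming are replaced by the declarative mapper plus AutoWeightsLoader. The patch_embedding.weight to patch_embedding._linear.weight rename is now expressed as an orig_to_new_suffix rule, and the accompanying reshape happens in the new MllamaVisionModel.load_weights shown earlier. A small illustration of how a prefix rule and a suffix rule combine on one key; this mimics, but is not, vLLM's WeightsMapper, and the key is invented for the example:

# Illustration only: prefix and suffix rules combining on a single key.
def map_mllama_key(name):
    prefix_map = {"model.vision_model.": "vision_model."}
    suffix_map = {"patch_embedding.weight": "patch_embedding._linear.weight"}
    for old, new in prefix_map.items():
        if name.startswith(old):
            name = new + name[len(old):]
    for old, new in suffix_map.items():
        if name.endswith(old):
            name = name[:-len(old)] + new
    return name

assert (map_mllama_key("model.vision_model.patch_embedding.weight")
        == "vision_model.patch_embedding._linear.weight")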
@@ -24,8 +24,9 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 logger = init_logger(__name__)
@@ -227,6 +228,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -395,4 +405,4 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)