[LoRA] Support dynamically initializing packed_modules_mapping for VLMs with arbitrary components (#18987)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-06-01 11:06:57 +08:00 committed by GitHub
parent 6aa8f9a4e7
commit a35ca765a5
7 changed files with 32 additions and 38 deletions

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-import copy
 import math
 import os
 from collections.abc import Sequence
@@ -34,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
+from vllm.model_executor.utils import get_packed_modules_mapping
 from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)
@@ -364,8 +364,8 @@ class LoRAModelManager(AdapterModelManager):
             # We need to replace rotary emb layer to do batch computation
             # for long lora.
             self.supported_lora_modules.append("rotary_emb")
-        self.packed_modules_mapping = copy.deepcopy(
-            self.model.packed_modules_mapping)
+        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # ruff: noqa: SIM117
-import copy
 import fnmatch
 import glob
 import itertools
@@ -36,7 +35,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     pt_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.models import is_pooling_model
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import (get_packed_modules_mapping,
+                                       set_weight_attrs)
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -420,8 +420,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet. No 'packed_modules_mapping' found.")
         self.is_pool_model=is_pooling_model(model)
-        self.modules_mapping = ParamMapping(
-            copy.deepcopy(model.packed_modules_mapping))
+        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
         # For some models like Molmo, we need to use hf_to_vllm_mapper
         # to ensure correct loading of weights.

View File

@@ -415,6 +415,10 @@ class InternVisionEncoder(nn.Module):
 class InternVisionModel(nn.Module):

+    packed_modules_mapping = {
+        "qkv": ["qkv"],
+    }
+
     def __init__(
         self,
         config: PretrainedConfig,
View File

@@ -1019,15 +1019,6 @@ class InternVLMultiModalProcessor(
 class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                         SupportsLoRA):
-
-    packed_modules_mapping = {
-        "wqkv": ["wqkv"],
-        "qkv": ["qkv"],
-        "gate_up_proj": [
-            "w1",
-            "w3",
-        ],
-    }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()

View File

@@ -821,17 +821,6 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
                                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={

View File

@@ -1069,17 +1069,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
                                         dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Utils for model executor."""
+import copy
 from typing import Any, Optional

 import torch
@@ -51,3 +52,23 @@ def _make_synced_weight_loader(original_weight_loader):
         torch._sync(param)

     return _synced_weight_loader
+
+
+def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
+    parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
+
+    # don't infer mapping if the model has defined it explicitly.
+    if parent_map:
+        return parent_map
+
+    # We only check main components instead of whole model submodules
+    for child in model.children():
+        child_map = getattr(child, "packed_modules_mapping", {})
+
+        if any((k in parent_map and parent_map[k] != v)
+               for k, v in child_map.items()):
+            raise ValueError(
+                f"Can't update {type(model).__name__}'s packed_modules_mapping "
+                f"safely because of conflicts from {type(child).__name__}.")
+        else:
+            parent_map.update(child_map)
+    return parent_map
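
For reference, a minimal usage sketch of the new helper. The Toy* classes below are hypothetical stand-ins for a VLM's vision and language components and are not part of vLLM; only get_packed_modules_mapping (and its import path, as added in this commit) comes from the source above. It assumes a vLLM installation that includes this change.

# Illustrative only: ToyVisionTower, ToyLanguageModel and ToyVLM are made up.
import torch.nn as nn

from vllm.model_executor.utils import get_packed_modules_mapping


class ToyVisionTower(nn.Module):
    # Each component declares its own fused-projection mapping,
    # as InternVisionModel now does in this commit.
    packed_modules_mapping = {"qkv": ["qkv"]}


class ToyLanguageModel(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }


class ToyVLM(nn.Module):
    # No class-level packed_modules_mapping: it is inferred from the
    # direct children instead of being hard-coded on the wrapper model.
    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = ToyVisionTower()
        self.language_model = ToyLanguageModel()


print(get_packed_modules_mapping(ToyVLM()))
# {'qkv': ['qkv'], 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
#  'gate_up_proj': ['gate_proj', 'up_proj']}

A model that still defines packed_modules_mapping explicitly short-circuits the inference (the early return), and if two components declare the same packed key with different sub-module lists the helper raises a ValueError rather than silently overwriting one of them.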