[LoRA] Add LoRA support for InternVL (#18842)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-05-29 16:46:24 +08:00 committed by GitHub
parent 972eddf7c9
commit 34d6c447c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
@ -36,7 +37,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -1014,7 +1016,17 @@ class InternVLMultiModalProcessor(
InternVLMultiModalProcessor,
info=InternVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
packed_modules_mapping = {
"wqkv": ["wqkv"],
"qkv": ["qkv"],
"gate_up_proj": [
"w1",
"w3",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@ -1403,3 +1415,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp1",
tower_model="vision_model")