[Misc] Gemma3ForConditionalGeneration supports LoRA (#14797)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-03-14 16:07:30 +08:00 committed by GitHub
parent 54cc46f3eb
commit 989ecd2007
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
@ -23,7 +24,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -371,8 +373,8 @@ class Gemma3MultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -614,3 +616,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")