LLaMA4 LoRA Adapter Enablement (#28602)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
Fardin Hoque 2025-11-14 10:27:56 -08:00 committed by GitHub
parent 9261eb3dc1
commit 964d65deed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
@ -68,11 +70,15 @@ from .interfaces import (
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import (
AutoWeightsLoader,
maybe_prefix,
)
from .vision import run_dp_sharded_vision_model
@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
nn.Module,
SupportsMultiModal,
SupportsPP,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
):
merge_by_field_config = True
@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration(
return updated_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Build the expert-parameter mapping for the fused-MoE layers.

    Delegates to ``FusedMoE.make_expert_params_mapping`` using this
    model's text-config expert counts. Entries cover expert weights as
    well as fp8 weight/activation scales, each shaped as
    ``(param_name, weight_name, expert_id, shard_id)``.
    """
    # Expert counts come from the language-model (text) sub-config.
    text_cfg = self.config.text_config
    return FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_up_proj_name="up_proj",
        ckpt_down_proj_name="down_proj",
        num_experts=text_cfg.num_local_experts,
        num_redundant_experts=self.num_redundant_experts,
    )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration(
)
return updated_params
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module prefixes of this multimodal model's components.

    Maps the language model, the multimodal connector (projector), and
    the vision tower to their module-name prefixes so callers (e.g. LoRA
    handling) can locate each sub-module.
    """
    return MultiModelKeys.from_string_field(
        tower_model="vision_model.",
        connector="multi_modal_projector.",
        language_model="language_model",
    )