diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index 14e741f32258..e25a104d822a 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.utils import initialize_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
@@ -68,11 +70,15 @@ from .interfaces import (
     MixtureOfExperts,
     MultiModalEmbeddings,
     SupportsEagle3,
+    SupportsLoRA,
     SupportsMultiModal,
     SupportsPP,
 )
 from .llama4 import Llama4ForCausalLM
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import (
+    AutoWeightsLoader,
+    maybe_prefix,
+)
 from .vision import run_dp_sharded_vision_model
 
 
@@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
 class Llama4ForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
+    nn.Module,
+    SupportsMultiModal,
+    SupportsPP,
+    MixtureOfExperts,
+    SupportsEagle3,
+    SupportsLoRA,
 ):
     merge_by_field_config = True
 
@@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration(
 
         return updated_params
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.text_config.num_local_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration(
         )
 
         return updated_params
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector.",
+            tower_model="vision_model.",
+        )