LLaMA4 LoRA Adapter Enablement (#28602)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
parent 9261eb3dc1
commit 964d65deed
@@ -35,6 +35,7 @@ from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -45,6 +46,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.utils import initialize_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
@@ -68,11 +70,15 @@ from .interfaces import (
     MixtureOfExperts,
     MultiModalEmbeddings,
     SupportsEagle3,
+    SupportsLoRA,
     SupportsMultiModal,
     SupportsPP,
 )
 from .llama4 import Llama4ForCausalLM
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import (
+    AutoWeightsLoader,
+    maybe_prefix,
+)
 from .vision import run_dp_sharded_vision_model
@@ -724,7 +730,12 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
 class Llama4ForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
+    nn.Module,
+    SupportsMultiModal,
+    SupportsPP,
+    MixtureOfExperts,
+    SupportsEagle3,
+    SupportsLoRA,
 ):
     merge_by_field_config = True
 
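Note on the hunk above: adding SupportsLoRA to the base-class list is what flags the model as LoRA-capable to the rest of vLLM; the engine checks this flag before it will accept adapters for the model. A minimal runtime check (a sketch, assuming the supports_lora type-guard in vllm.model_executor.models.interfaces and the mllama4 module path, both as in the vLLM tree this commit targets):

# Sketch: verify the LoRA capability flag at runtime.
from vllm.model_executor.models.interfaces import supports_lora
from vllm.model_executor.models.mllama4 import Llama4ForConditionalGeneration

# With SupportsLoRA in the base-class list, the type-guard vLLM uses to
# gate LoRA support should now return True for this multimodal class.
assert supports_lora(Llama4ForConditionalGeneration)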
@@ -1067,6 +1078,17 @@ class Llama4ForConditionalGeneration(
 
         return updated_params
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.text_config.num_local_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
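For orientation: the tuples returned by get_expert_mapping let load_weights route each per-expert checkpoint tensor (gate_proj/up_proj/down_proj) into the fused MoE parameters. The standalone sketch below imitates the shape of that mapping; the fused-parameter prefixes (experts.w13_/experts.w2_) and shard ids (w1/w2/w3) follow FusedMoE's usual convention, but this is an illustration, not vLLM's exact implementation:

# Simplified, standalone illustration of the
# (param_name, weight_name, expert_id, shard_id) mapping.
def expert_params_mapping_sketch(
    gate: str, down: str, up: str, num_experts: int
) -> list[tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (("w1", gate), ("w2", down), ("w3", up)):
            # gate_proj (w1) and up_proj (w3) load into the fused w13
            # parameter; down_proj (w2) loads into w2.
            fused = "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_"
            mapping.append(
                (fused, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id)
            )
    return mapping

# e.g. the first few entries for a two-expert toy config:
for row in expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", 2)[:3]:
    print(row)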
@@ -1113,3 +1135,13 @@ class Llama4ForConditionalGeneration(
         )
 
         return updated_params
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector.",
+            tower_model="vision_model.",
+        )
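Taken together — the SupportsLoRA base class, the expert mapping, and get_mm_mapping (which tells the LoRA machinery which module prefixes are the language model versus the vision tower and projector) — the multimodal Llama4 model can now take LoRA adapters through the standard vLLM entry points. A minimal offline-inference sketch; the model id and adapter path are illustrative placeholders, to be replaced with a Llama-4 checkpoint and an adapter trained on its language model:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model id and adapter path -- substitute your own.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    enable_lora=True,
    max_lora_rank=64,
)

sampling = SamplingParams(temperature=0.0, max_tokens=64)

# LoRARequest(name, unique int id, local adapter path); the adapter is
# loaded and registered on first use.
outputs = llm.generate(
    ["What is the capital of France?"],
    sampling,
    lora_request=LoRARequest("my_adapter", 1, "/path/to/llama4-lora"),
)
print(outputs[0].outputs[0].text)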