Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
Zhiyu 2025-07-11 23:07:16 -07:00 committed by GitHub
parent 5de8d9f111
commit 4afe687a82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 501 additions and 35 deletions

View File

@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def uses_weight_scale_2_pattern(self) -> bool:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return False
@staticmethod @staticmethod
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name: if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name # Determine per-tensor weight scale patterns based on variant
or 'input_scale' in weight_name): # Use the dedicated method instead of brittle string matching
self._load_per_tensor_weight_scale(shard_id=shard_id, uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
param=param, )
loaded_weight=loaded_weight,
expert_id=expert_id) # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
per_tensor_conditions = (
"weight_scale_2" in weight_name if uses_weight_scale_2 else
"weight_scale" in weight_name) or "input_scale" in weight_name
if per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
elif "weight" in weight_name: elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale( self._load_model_weight_or_group_weight_scale(
shard_id=shard_id, shard_id=shard_id,
@ -1558,3 +1579,7 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]

View File

@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
def __init__( def __init__(
self, self,
is_checkpoint_fp8_serialized: bool = False, is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[list[str]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized: if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.") " the format is experimental and could change.")
@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"]) quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"] quant_method = quant_config["quant_algo"]
kv_cache_quant_method = cls.get_from_keys(
config, ["quantization"]).get("kv_cache_quant_algo")
exclude_modules = cls.get_from_keys(
config, ["quantization"]).get("exclude_modules")
if quant_method not in QUANT_ALGOS: if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the " " quantizations in vLLM. Please check the "
@ -76,15 +85,39 @@ class ModelOptFp8Config(QuantizationConfig):
"quant configuration.") "quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method) is_checkpoint_fp8_serialized = ("FP8" in quant_method)
return cls(is_checkpoint_fp8_serialized) return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
exclude_modules)
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the
module name (without the language_model prefix) is in the exclude list.
"""
if self.exclude_modules is None:
return False
# Check if any excluded module matches the prefix
for module in self.exclude_modules:
if (module in prefix
or (prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model."))):
return True
return False
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]: prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self) return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return None return None
@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
bias=bias) bias=bias)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and
activation scale.
Args:
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Use FP8 dtype if checkpoint is serialized
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight_loader = extra_weight_attrs.get("weight_loader")
w13_weight = ModelWeightParameter(
data=torch.empty(num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=weight_dtype),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = ModelWeightParameter(
data=torch.empty(num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=weight_dtype),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight", w2_weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
1.0,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Set weight loader attributes for scales
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
from vllm._custom_ops import scaled_fp8_quant
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize)
# Handle scale parameters
if hasattr(layer,
"w13_weight_scale") and layer.w13_weight_scale is not None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2:
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# Requantize each expert's weights using the combined scale
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size = layer.w13_weight.shape[1] // 2
for expert_id in range(layer.w13_weight.shape[0]):
start = 0
for shard_id in range(2): # w1 and w3
# Dequantize using the original scale for this shard
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
intermediate_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
# Requantize using the combined max scale
(
layer.w13_weight[expert_id][start:start +
intermediate_size, :],
_,
) = scaled_fp8_quant(dq_weight,
max_w13_scales[expert_id])
start += intermediate_size
# Update the scale parameter to be per-expert
layer.w13_weight_scale = Parameter(max_w13_scales,
requires_grad=False)
else:
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
requires_grad=False)
if hasattr(layer,
"w2_weight_scale") and layer.w2_weight_scale is not None:
layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
requires_grad=False)
# Input scales must be equal for each expert in fp8 MoE layers.
if hasattr(layer,
"w13_input_scale") and layer.w13_input_scale is not None:
layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
requires_grad=False)
if hasattr(layer,
"w2_input_scale") and layer.w2_input_scale is not None:
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class ModelOptNvFp4Config(QuantizationConfig): class ModelOptNvFp4Config(QuantizationConfig):
"""Config class for ModelOpt FP4.""" """Config class for ModelOpt FP4."""
@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and" " quantization. Please use Blackwell and"
" above.") " above.")
def uses_weight_scale_2_pattern(self) -> bool:
"""
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
"""
return True
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):

View File

@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
modelopt_scale_names = [ modelopt_scale_names = [
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
] ]
# Also support qkv_proj scale parameters (from stacked parameter processing)
qkv_proj_scale_names = [
".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
]
for scale_name in possible_scale_names: for scale_name in possible_scale_names:
if name.endswith(scale_name): if name.endswith(scale_name):
if any(mo_scale_name in name if any(mo_scale_name in name
@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name = name.replace( remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}", f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}") f".self_attn.attn{scale_name}")
elif any(qkv_scale_name in name
for qkv_scale_name in qkv_proj_scale_names):
# Handle qkv_proj scale parameters
remapped_name = name.replace(
f".self_attn.qkv_proj{scale_name}",
f".self_attn.attn{scale_name}")
else: else:
remapped_name = name.replace(scale_name, f".attn{scale_name}") remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict: if remapped_name not in params_dict:

View File

@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -432,12 +433,24 @@ class Llama4Model(LlamaModel):
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name: if weight_name not in name or "experts" in name:
continue continue
name = name.replace(weight_name, param_name) # This check is for ModelOpt ckpts with kv cache quant enabled
if not (name.endswith(
(".k_scale", ".v_scale")) and "self_attn" in name):
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
if name.endswith("scale") and "expert" not in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = getattr(param, "weight_loader",
weight_loader(param, loaded_weight, shard_id) default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name) loaded_params.add(name)
break break
else: else:
@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
if not moe_loaded: if not moe_loaded:
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if ("experts." in name and ("w13_input_scale" in name
or "w13_weight_scale" in name
or "w2_input_scale" in name
or "w2_weight_scale" in name)):
# These are flat expert scales that apply to all experts
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe = getattr(weight_loader,
'supports_moe_loading', False)
if supports_moe:
# This is a MoE weight loader
if "w13_" in name:
shard_id = "w1"
elif "w2_" in name:
shard_id = "w2"
else:
shard_id = "w1"
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=0)
else:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
} }
@classmethod @classmethod
@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
qkv_weight = torch.cat(weight, dim=0) qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str, def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
torch.Tensor]]) -> set[str]: """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
format."""
if name.startswith("model."):
# Handle expert scale parameters with flat naming
if "feed_forward.experts." in name and ("_input_scale" in name or
"_weight_scale" in name):
renamed = name.replace("model.", "language_model.model.", 1)
# Map checkpoint naming to vLLM's expected naming
if "down_proj_input_scale" in renamed:
return renamed.replace("down_proj_input_scale",
"w2_input_scale")
elif "down_proj_weight_scale" in renamed:
return renamed.replace("down_proj_weight_scale",
"w2_weight_scale")
elif "gate_up_proj_input_scale" in renamed:
return renamed.replace("gate_up_proj_input_scale",
"w13_input_scale")
elif "gate_up_proj_weight_scale" in renamed:
return renamed.replace("gate_up_proj_weight_scale",
"w13_weight_scale")
return renamed
stacked_params_mapping = [ # Handle attention scale parameters
# (param_name, shard_name, shard_id) elif "self_attn." in name and (".k_scale" in name
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), or ".v_scale" in name):
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"), renamed = name.replace("model.", "language_model.model.", 1)
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"), if ".k_proj.k_scale" in renamed:
] return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
params_dict = dict(self.named_parameters()) elif ".v_proj.v_scale" in renamed:
updated_params: set[str] = set() return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
return renamed
# language_model is an Llama4ForCausalLM instance. We load it's # Standard model.* to language_model.model.* renaming
# using llama4's load_weights routine. return name.replace("model.", "language_model.model.", 1)
language_model_weights, other_weights = self.separate_weights(
weights, prefix="language_model.") elif name.startswith("lm_head.weight"):
loader = AutoWeightsLoader(self) return name.replace("lm_head.weight",
loaded_language_model_params = loader.load_weights( "language_model.lm_head.weight")
language_model_weights)
assert loaded_language_model_params is not None return name
updated_params.update(loaded_language_model_params)
def _separate_and_rename_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
"""Rename weights and separate them into language_model and other
weights."""
language_model_weights = []
other_weights = []
for name, weight in weights:
renamed = self._rename_weight_for_modelopt_checkpoint(name)
if renamed.startswith("language_model."):
language_model_weights.append((renamed, weight))
else:
other_weights.append((renamed, weight))
return language_model_weights, other_weights
def _handle_expert_scale_broadcasting(
self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
"""Handle expert scale parameters that need broadcasting.
ModelOpt checkpoints use a single value tensor scalar for BMM style
experts, vLLM expects the scale to be broadcasted across all experts.
"""
regular_weights = []
expert_scale_weights = []
updated_params = set()
for name, weight in weights:
# Check if this is an expert scale parameter that needs broadcasting
if ("feed_forward.experts." in name and "scale" in name
and ".shared_expert" not in name):
if name in params_dict:
param = params_dict[name]
if (hasattr(param, 'data') and param.data.numel() > 1
and weight.numel() == 1):
# Broadcast single value to all experts
param.data.fill_(weight.item())
updated_params.add(name)
continue
expert_scale_weights.append((name, weight))
else:
regular_weights.append((name, weight))
return regular_weights, expert_scale_weights, updated_params
def _load_other_weights(self, other_weights: Iterable[tuple[str,
torch.Tensor]],
params_dict: dict,
stacked_params_mapping: list) -> set[str]:
"""Load non-language-model weights with stacking support."""
updated_params = set()
if self.use_data_parallel: if self.use_data_parallel:
other_weights = self._consolidate_qkv_weights(other_weights) other_weights = self._consolidate_qkv_weights(other_weights)
for name, loaded_weight in other_weights: for name, loaded_weight in other_weights:
# Try stacked parameter mapping first
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or self.use_data_parallel: if weight_name not in name or self.use_data_parallel:
continue continue
@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Use regular weight loading
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
updated_params.add(name) updated_params.add(name)
return updated_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
# Shared expert gate_up_proj stacking
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
# Feed forward gate_up_proj stacking (for non-MoE layers if any)
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
# Separate and rename weights
language_model_weights, other_weights = (
self._separate_and_rename_weights(weights))
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights,
params_dict))
updated_params.update(updated_params_from_experts)
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(regular_weights)
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
if expert_scale_weights:
loaded_expert_scale_params = loader.load_weights(
expert_scale_weights)
if loaded_expert_scale_params:
updated_params.update(loaded_expert_scale_params)
updated_params.update(
self._load_other_weights(other_weights, params_dict,
stacked_params_mapping))
return updated_params return updated_params