Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Zhiyu 2025-07-11 23:07:16 -07:00 committed by GitHub
parent 5de8d9f111
commit 4afe687a82
5 changed files with 501 additions and 35 deletions

View File

@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
def uses_weight_scale_2_pattern(self) -> bool:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return False
@staticmethod
def maybe_make_prepare_finalize(
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
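
For illustration, a minimal sketch of how a quantization method opts into the 'weight_scale_2' naming via the hook added above; the subclass name is hypothetical and the other abstract methods of FusedMoEMethodBase are omitted, so this is not instantiable as written:

class HypotheticalFp4MoEMethod(FusedMoEMethodBase):
    """Hypothetical FP4-style method, shown only to illustrate the hook."""

    def uses_weight_scale_2_pattern(self) -> bool:
        # FP4 variants keep per-tensor weight scales under 'weight_scale_2',
        # so the MoE weight loader must match that suffix instead of the
        # standard 'weight_scale'.
        return True
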
@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
)
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
per_tensor_conditions = (
"weight_scale_2" in weight_name if uses_weight_scale_2 else
"weight_scale" in weight_name) or "input_scale" in weight_name
if per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
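
To make the branching above concrete, a standalone sketch of how the per-tensor condition evaluates for typical ModelOpt weight names (the names themselves are illustrative):

def _is_per_tensor(weight_name: str, uses_weight_scale_2: bool) -> bool:
    # Mirrors the per_tensor_conditions expression above.
    return ("weight_scale_2" in weight_name if uses_weight_scale_2 else
            "weight_scale" in weight_name) or "input_scale" in weight_name

# FP8 (uses_weight_scale_2=False): plain 'weight_scale' is per-tensor.
assert _is_per_tensor("experts.w13_weight_scale", False)
# FP4 (uses_weight_scale_2=True): only the '_2' suffix counts as per-tensor.
assert not _is_per_tensor("experts.w13_weight_scale", True)
assert _is_per_tensor("experts.w13_weight_scale_2", True)
# Input scales are always routed to the per-tensor loader.
assert _is_per_tensor("experts.w2_input_scale", False)
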
@ -1558,3 +1579,7 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]

View File

@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[list[str]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")
@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
kv_cache_quant_method = cls.get_from_keys(
config, ["quantization"]).get("kv_cache_quant_algo")
exclude_modules = cls.get_from_keys(
config, ["quantization"]).get("exclude_modules")
if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the "
@ -76,27 +85,51 @@ class ModelOptFp8Config(QuantizationConfig):
"quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
return cls(is_checkpoint_fp8_serialized)
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
exclude_modules)
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the
module name (without the language_model prefix) is in the exclude list.
"""
if self.exclude_modules is None:
return False
# Check if any excluded module matches the prefix
for module in self.exclude_modules:
if (module in prefix
or (prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model."))):
return True
return False
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return None
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
scales.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
"""
@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
bias=bias)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and
activation scale.
Args:
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Use FP8 dtype if checkpoint is serialized
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight_loader = extra_weight_attrs.get("weight_loader")
w13_weight = ModelWeightParameter(
data=torch.empty(num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=weight_dtype),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = ModelWeightParameter(
data=torch.empty(num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=weight_dtype),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight", w2_weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpt
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
1.0,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Set weight loader attributes for scales
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
from vllm._custom_ops import scaled_fp8_quant
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize)
# Handle scale parameters
if hasattr(layer,
"w13_weight_scale") and layer.w13_weight_scale is not None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2:
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# Requantize each expert's weights using the combined scale
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size = layer.w13_weight.shape[1] // 2
for expert_id in range(layer.w13_weight.shape[0]):
start = 0
for shard_id in range(2): # w1 and w3
# Dequantize using the original scale for this shard
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
intermediate_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
# Requantize using the combined max scale
(
layer.w13_weight[expert_id][start:start +
intermediate_size, :],
_,
) = scaled_fp8_quant(dq_weight,
max_w13_scales[expert_id])
start += intermediate_size
# Update the scale parameter to be per-expert
layer.w13_weight_scale = Parameter(max_w13_scales,
requires_grad=False)
else:
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
requires_grad=False)
if hasattr(layer,
"w2_weight_scale") and layer.w2_weight_scale is not None:
layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
requires_grad=False)
# Input scales must be equal for each expert in fp8 MoE layers.
if hasattr(layer,
"w13_input_scale") and layer.w13_input_scale is not None:
layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
requires_grad=False)
if hasattr(layer,
"w2_input_scale") and layer.w2_input_scale is not None:
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class ModelOptNvFp4Config(QuantizationConfig):
"""Config class for ModelOpt FP4."""
@ -274,7 +524,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
class ModelOptNvFp4LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
input_scale: torch.float32, scalar,
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
@ -455,7 +705,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
Args:
quant_config: NVFP4 Quant Config
"""
@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above.")
def uses_weight_scale_2_pattern(self) -> bool:
"""
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
"""
return True
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

View File

@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
modelopt_scale_names = [
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
]
# Also support qkv_proj scale parameters (from stacked parameter processing)
qkv_proj_scale_names = [
".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
if any(mo_scale_name in name
@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}")
elif any(qkv_scale_name in name
for qkv_scale_name in qkv_proj_scale_names):
# Handle qkv_proj scale parameters
remapped_name = name.replace(
f".self_attn.qkv_proj{scale_name}",
f".self_attn.attn{scale_name}")
else:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
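
For reference, an illustrative remap (hypothetical layer index) exercising the qkv_proj branch added above, assuming the model's params_dict already contains the remapped target name:

from vllm.model_executor.model_loader.weight_utils import (
    maybe_remap_kv_scale_name)

# Stacked-parameter checkpoint name -> vLLM attention scale name.
params_dict = {"model.layers.0.self_attn.attn.k_scale": object()}
remapped = maybe_remap_kv_scale_name(
    "model.layers.0.self_attn.qkv_proj.k_scale", params_dict)
assert remapped == "model.layers.0.self_attn.attn.k_scale"
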

View File

@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -432,12 +433,24 @@ class Llama4Model(LlamaModel):
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name:
continue
name = name.replace(weight_name, param_name)
# This check is for ModelOpt ckpts with kv cache quant enabled
if not (name.endswith(
(".k_scale", ".v_scale")) and "self_attn" in name):
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale") and "expert" not in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if ("experts." in name and ("w13_input_scale" in name
or "w13_weight_scale" in name
or "w2_input_scale" in name
or "w2_weight_scale" in name)):
# These are flat expert scales that apply to all experts
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe = getattr(weight_loader,
'supports_moe_loading', False)
if supports_moe:
# This is a MoE weight loader
if "w13_" in name:
shard_id = "w1"
elif "w2_" in name:
shard_id = "w2"
else:
shard_id = "w1"
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=0)
else:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

View File

@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
@classmethod
@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
format."""
if name.startswith("model."):
# Handle expert scale parameters with flat naming
if "feed_forward.experts." in name and ("_input_scale" in name or
"_weight_scale" in name):
renamed = name.replace("model.", "language_model.model.", 1)
# Map checkpoint naming to vLLM's expected naming
if "down_proj_input_scale" in renamed:
return renamed.replace("down_proj_input_scale",
"w2_input_scale")
elif "down_proj_weight_scale" in renamed:
return renamed.replace("down_proj_weight_scale",
"w2_weight_scale")
elif "gate_up_proj_input_scale" in renamed:
return renamed.replace("gate_up_proj_input_scale",
"w13_input_scale")
elif "gate_up_proj_weight_scale" in renamed:
return renamed.replace("gate_up_proj_weight_scale",
"w13_weight_scale")
return renamed
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
# Handle attention scale parameters
elif "self_attn." in name and (".k_scale" in name
or ".v_scale" in name):
renamed = name.replace("model.", "language_model.model.", 1)
if ".k_proj.k_scale" in renamed:
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
elif ".v_proj.v_scale" in renamed:
return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
return renamed
# language_model is a Llama4ForCausalLM instance. We load its weights
# using llama4's load_weights routine.
language_model_weights, other_weights = self.separate_weights(
weights, prefix="language_model.")
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(
language_model_weights)
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
# Standard model.* to language_model.model.* renaming
return name.replace("model.", "language_model.model.", 1)
elif name.startswith("lm_head.weight"):
return name.replace("lm_head.weight",
"language_model.lm_head.weight")
return name
def _separate_and_rename_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
"""Rename weights and separate them into language_model and other
weights."""
language_model_weights = []
other_weights = []
for name, weight in weights:
renamed = self._rename_weight_for_modelopt_checkpoint(name)
if renamed.startswith("language_model."):
language_model_weights.append((renamed, weight))
else:
other_weights.append((renamed, weight))
return language_model_weights, other_weights
def _handle_expert_scale_broadcasting(
self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]],
           set[str]]:
"""Handle expert scale parameters that need broadcasting.
ModelOpt checkpoints store a single scalar scale for BMM-style experts,
while vLLM expects the scale to be broadcast across all experts.
"""
regular_weights = []
expert_scale_weights = []
updated_params = set()
for name, weight in weights:
# Check if this is an expert scale parameter that needs broadcasting
if ("feed_forward.experts." in name and "scale" in name
and ".shared_expert" not in name):
if name in params_dict:
param = params_dict[name]
if (hasattr(param, 'data') and param.data.numel() > 1
and weight.numel() == 1):
# Broadcast single value to all experts
param.data.fill_(weight.item())
updated_params.add(name)
continue
expert_scale_weights.append((name, weight))
else:
regular_weights.append((name, weight))
return regular_weights, expert_scale_weights, updated_params
def _load_other_weights(self, other_weights: Iterable[tuple[str,
torch.Tensor]],
params_dict: dict,
stacked_params_mapping: list) -> set[str]:
"""Load non-language-model weights with stacking support."""
updated_params = set()
if self.use_data_parallel:
other_weights = self._consolidate_qkv_weights(other_weights)
for name, loaded_weight in other_weights:
# Try stacked parameter mapping first
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or self.use_data_parallel:
continue
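
To make _rename_weight_for_modelopt_checkpoint above concrete, a few illustrative renames (hypothetical layer indices), assuming `model` is an instantiated Llama4ForConditionalGeneration:

expected = {
    "model.layers.3.feed_forward.experts.down_proj_weight_scale":
        "language_model.model.layers.3.feed_forward.experts.w2_weight_scale",
    "model.layers.3.feed_forward.experts.gate_up_proj_input_scale":
        "language_model.model.layers.3.feed_forward.experts.w13_input_scale",
    "model.layers.3.self_attn.k_proj.k_scale":
        "language_model.model.layers.3.self_attn.attn.k_scale",
    "lm_head.weight": "language_model.lm_head.weight",
}
for src, dst in expected.items():
    # Other "model.*" names are simply prefixed with "language_model.".
    assert model._rename_weight_for_modelopt_checkpoint(src) == dst
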
@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader(param, loaded_weight, shard_id)
break
else:
# Use regular weight loading
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
# Shared expert gate_up_proj stacking
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
# Feed forward gate_up_proj stacking (for non-MoE layers if any)
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: set[str] = set()
# Separate and rename weights
language_model_weights, other_weights = (
self._separate_and_rename_weights(weights))
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights,
params_dict))
updated_params.update(updated_params_from_experts)
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(regular_weights)
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
if expert_scale_weights:
loaded_expert_scale_params = loader.load_weights(
expert_scale_weights)
if loaded_expert_scale_params:
updated_params.update(loaded_expert_scale_params)
updated_params.update(
self._load_other_weights(other_weights, params_dict,
stacked_params_mapping))
return updated_params
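
Finally, a minimal torch sketch (hypothetical expert count and scale value) of the broadcast that _handle_expert_scale_broadcasting performs when a ModelOpt checkpoint stores one scalar scale but vLLM allocates a per-expert parameter:

import torch

num_experts = 16
# vLLM-side parameter, e.g. w13_input_scale, with one slot per expert.
param = torch.nn.Parameter(torch.ones(num_experts), requires_grad=False)
# ModelOpt checkpoint stores a single scalar for the BMM-style expert block.
loaded_weight = torch.tensor(0.0125)

if param.data.numel() > 1 and loaded_weight.numel() == 1:
    # Replicate the scalar across all experts, mirroring param.data.fill_ above.
    param.data.fill_(loaded_weight.item())

assert torch.all(param.data == 0.0125)
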