[NVIDIA] Fix Llama4 Scout FP4 functionality issues (#21499)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
Po-Han Huang (NVIDIA) 2025-07-30 22:33:40 +08:00 committed by GitHub
parent 8f4a1c9a04
commit ff08e51940
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 218 additions and 69 deletions

View File

@ -874,6 +874,14 @@ class FusedMoE(torch.nn.Module):
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_w13_weight_scale(self, shard_dim: int,
loaded_weight: torch.Tensor,
param: torch.Tensor, tp_rank: int):
shard_size = param.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
param.copy_(loaded_weight)
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
@ -1123,7 +1131,12 @@ class FusedMoE(torch.nn.Module):
"weight_scale_2" in weight_name if uses_weight_scale_2 else
"weight_scale" in weight_name) or "input_scale" in weight_name
if per_tensor_conditions:
if "w13_weight_scale" in weight_name:
self._load_w13_weight_scale(shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
tp_rank=self.tp_rank)
elif per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,

View File

@ -778,8 +778,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert (layer.weight_scale.shape[1] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

View File

@ -342,34 +342,94 @@ class Llama4Model(LlamaModel):
expert_params_mapping: list[tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
"""
Load MoE expert weights.
Args:
name: The name of the weight to load.
loaded_weight: The weight to load.
params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be
generated by FusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of:
[num_experts, hidden_in, hidden_out] for gate/up/down proj and
[hidden_out, hidden_in] for the others like router.
When fused is False, loaded_weight should have shape of:
[hidden_out, hidden_in].
Returns:
True if loaded_weight is one of MoE weights and the MoE expert
weights are loaded successfully, False otherwise.
"""
# Whether the MoE expert weights are loaded successfully.
expert_param_loaded = False
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-1)
# If fused is True, the loaded weight is in the layout of:
# [num_experts, hidden_in, hidden_out], so we must transpose the last
# two dimensions to match the expected layout of the parameters.
if fused and loaded_weight.ndim == 3:
loaded_weight = loaded_weight.transpose(-1, -2)
# If the gate_proj and up_proj weights are fused into a single
# weight tensor, we need to split the weight tensor into a tuple
# of two weight tensors along the hidden_out dimension.
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
# Iterate over all the expert parameters and load the weights if we find
# a match in weight name.
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
# Get a view of the loaded_weight to avoid modifying the original
# one across iterations.
new_loaded_weight = loaded_weight
# If expert weights are fused into a single weight tensor, remove
# the expert index from the expected weight name.
if fused:
# The string between e_str and proj_str is the expert index.
e_str, _, proj_str, _ = weight_name.split('.')
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
# Skip if the current weight is not one of the MoE weights.
if weight_name not in name:
continue
# Replace the weight name with the parameter name.
full_param_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
# Skip if the current weight is for the bias.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
# If the parameter is for w13 together, the corresponding weight
# will be a tuple, so we must select the correct weight
# depending on the shard id, which is either "w1" or "w3".
if "w13" in full_param_name:
assert shard_id in ["w1", "w3"]
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
# If EP (expert parallel) is enabled, update expert_id to the
# starting expert index for the current EP rank and extract the
# corresponding expert weights.
layer_idx = extract_layer_index(name)
# EP mapping
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
@ -382,6 +442,9 @@ class Llama4Model(LlamaModel):
else:
# TODO: add EP support for non fused weights
pass
# Load the weight into the module parameter with corresponding
# shard id and expert id.
weight_loader(param,
new_loaded_weight,
full_param_name,
@ -390,10 +453,13 @@ class Llama4Model(LlamaModel):
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
# Name mapping from the parameter name to the shard name and
# corresponding shard id.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@ -402,26 +468,43 @@ class Llama4Model(LlamaModel):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
# Indicate whether the expert weights are fused into a single weight
# tensor.
fused_experts_params = False
# Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
# All the module parameters.
params_dict = dict(self.named_parameters())
# The module parameters that have been loaded.
loaded_params: set[str] = set()
# Iterate over all the weights and load them into module parameters.
for name, loaded_weight in weights:
# If the name contains "experts.gate_up_proj" or "experts.down_proj"
# without the expert indices, it means the expert weights are fused
# into a single weight tensor across all experts.
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
# If kv cache quantization scales exist and the weight name
# corresponds to one of the kv cache quantization scales, load
# them.
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
@ -430,84 +513,119 @@ class Llama4Model(LlamaModel):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Iterate over stacked_params_mapping to check if the current weight
# is one of the stacked parameters. If so, load the weight with the
# corresponding shard id. Note that MoE weights are handled
# separately in the else block.
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip if the current weight is not one of the stacked
# parameters or if the current weight is a MoE weight.
if weight_name not in name or "experts" in name:
continue
# This check is for ModelOpt ckpts with kv cache quant enabled
# For ModelOpt checkpoints, we need to rename the self_attn
# weight/weight_scale names except for kv cache scales.
if not (name.endswith(
(".k_scale", ".v_scale")) and "self_attn" in name):
name = name.replace(weight_name, param_name)
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale") and "expert" not in name:
# Remapping the name of FP8 kv-scale.
# Remap kv cache scale names for ModelOpt checkpoints.
# TODO: ModelOpt should implement get_cache_scale() such that
# kv cache scale name remapping can be done there.
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# Load the weight into the module parameter with corresponding
# shard id and exit the for loop and the else block.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
# Handle normal (non-stacked) weights and MoE weights.
else:
moe_loaded = self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params)
# First, try to load MoE weights using load_moe_expert_weights.
# If successful, move on to next loaded weight.
if self.load_moe_expert_weights(name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params):
continue
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if ("experts." in name and ("w13_input_scale" in name
or "w13_weight_scale" in name
or "w2_input_scale" in name
or "w2_weight_scale" in name)):
# These are flat expert scales that apply to all experts
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe = getattr(weight_loader,
'supports_moe_loading', False)
if supports_moe:
# This is a MoE weight loader
if "w13_" in name:
shard_id = "w1"
elif "w2_" in name:
shard_id = "w2"
else:
shard_id = "w1"
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=0)
else:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
# Handle flat expert scale parameters that don't match
# per-expert patterns, i.e. one weight scale tensor for all
# experts.
scale_names = [
"w13_input_scale", "w13_weight_scale", "w2_input_scale",
"w2_weight_scale"
]
if ("experts." in name and any(scale_name in name
for scale_name in scale_names)):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If weight loader supports special moe loading, use it to
# avoid expensive runtime reflection
if getattr(weight_loader, 'supports_moe_loading', False):
# Map the weight name to the corresponding shard id.
shard_id = "w2" if "w2_" in name else "w1"
# Transpose if weight scales are FP8 block scales with
# three dimensions:
# [num_experts, hidden_in, hidden_out].
if name.endswith("weight_scale") \
and loaded_weight.dtype == torch.float8_e4m3fn \
and loaded_weight.ndim == 3:
loaded_weight = loaded_weight.transpose(-1, -2)
# Load the weight into the module parameter with
# corresponding shard id and expert id.
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=0)
else:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
# Handle normal (non-stacked, non-MoE) weights.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
# Finally, return the set of loaded parameters.
return loaded_params
@ -560,23 +678,43 @@ class Llama4ForCausalLM(LlamaForCausalLM):
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
# Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
# Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
# If the weight is FP4 packed as uint8, we need to divide attn_out
# by 2.
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
attn_out = attn_out // 2
# If the weight is a weight scale, we need to divide attn_out by
# block size, which is currently 16.
elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
and w.shape[1] * 16 == attn_out:
attn_out = attn_out // 16
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif ("wq" in modules or "q_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
# Permute Q/K weights and weight block scales for rotary embedding
is_weight = modules[-1] == "weight"
is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
loaded_weight.dtype == torch.float8_e4m3fn)
if is_weight or is_nvfp4_weight_scale:
if ("wk" in modules or "k_proj" in modules):
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale)
elif ("wq" in modules or "q_proj" in modules):
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale)
return name, loaded_weight