[NVIDIA] Fix Llama4 Scout FP4 functionality issues (#21499)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent
8f4a1c9a04
commit
ff08e51940
@@ -874,6 +874,14 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight

+    def _load_w13_weight_scale(self, shard_dim: int,
+                               loaded_weight: torch.Tensor,
+                               param: torch.Tensor, tp_rank: int):
+        shard_size = param.shape[shard_dim]
+        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
+                                             shard_size)
+        param.copy_(loaded_weight)
+
     def _load_model_weight_or_group_weight_scale(self,
                                                  shard_dim: int,
                                                  expert_data: torch.Tensor,
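For orientation, a minimal standalone sketch (made-up shapes, hypothetical tp_rank) of the narrow-and-copy pattern the new _load_w13_weight_scale helper applies when sharding a w13 weight scale across tensor-parallel ranks:

    import torch

    # Hypothetical example: a w13 weight scale of length 8 in the checkpoint,
    # sharded across 2 tensor-parallel ranks along dim 0.
    tp_size, tp_rank, shard_dim = 2, 1, 0
    full_scale = torch.arange(8, dtype=torch.float32)       # checkpoint tensor
    param = torch.empty(8 // tp_size, dtype=torch.float32)  # this rank's slice

    shard_size = param.shape[shard_dim]                      # 4
    shard = full_scale.narrow(shard_dim, shard_size * tp_rank, shard_size)
    param.copy_(shard)
    print(param)  # tensor([4., 5., 6., 7.]) on rank 1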
@@ -1123,7 +1131,12 @@ class FusedMoE(torch.nn.Module):
                 "weight_scale_2" in weight_name if uses_weight_scale_2 else
                 "weight_scale" in weight_name) or "input_scale" in weight_name

-            if per_tensor_conditions:
+            if "w13_weight_scale" in weight_name:
+                self._load_w13_weight_scale(shard_dim=shard_dim,
+                                            loaded_weight=loaded_weight,
+                                            param=param,
+                                            tp_rank=self.tp_rank)
+            elif per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
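A rough, self-contained illustration (not part of the diff) of how that per-tensor condition evaluates for a few representative weight names; uses_weight_scale_2 is assumed here to be True for schemes that ship a separate global scale named weight_scale_2:

    def per_tensor_conditions(weight_name: str, uses_weight_scale_2: bool) -> bool:
        # Mirrors the boolean expression shown in the hunk above.
        return ("weight_scale_2" in weight_name if uses_weight_scale_2 else
                "weight_scale" in weight_name) or "input_scale" in weight_name

    assert per_tensor_conditions("experts.w13_weight_scale_2", True)
    assert per_tensor_conditions("experts.w2_input_scale", True)
    # Block scales named "w13_weight_scale" now take the dedicated
    # _load_w13_weight_scale path rather than the per-tensor branch.
    assert not per_tensor_conditions("experts.w13_weight_scale", True)
    assert per_tensor_conditions("experts.w13_weight_scale", False)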
@@ -778,8 +778,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         # Swizzle the weight blockscale.
         # contracting dimension is input dimension
         # block_size = 16;
-        assert (layer.weight_scale.shape[1] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
         swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
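For context, a toy sketch (made-up sizes, not taken from Llama4 Scout) of the NVFP4 layout the comments above describe: FP4 values packed two per uint8 along the input dimension, with one FP8-E4M3 block scale per 16 input channels:

    import torch

    out_features, in_features, block_size = 32, 80, 16

    # FP4 weights: two 4-bit values packed into each uint8 along the input dim.
    packed_weight = torch.zeros(out_features, in_features // 2, dtype=torch.uint8)
    # Block scales: one FP8-E4M3 scale per 16-element block of the input dim.
    weight_scale = torch.rand(out_features, in_features // block_size)
    weight_scale = weight_scale.to(torch.float8_e4m3fn)

    print(packed_weight.shape)                     # torch.Size([32, 40])
    print(weight_scale.shape, weight_scale.dtype)  # torch.Size([32, 5]) torch.float8_e4m3fn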
@@ -342,34 +342,94 @@ class Llama4Model(LlamaModel):
         expert_params_mapping: list[tuple[str, str, int, str]],
         fused: bool = True,
     ) -> bool:
+        """
+        Load MoE expert weights.
+
+        Args:
+            name: The name of the weight to load.
+            loaded_weight: The weight to load.
+            params_dict: The dictionary of module parameters.
+            loaded_params: The set of already loaded parameters.
+            expert_params_mapping: The mapping of expert parameters. Must be
+                generated by FusedMoE.make_expert_params_mapping().
+            fused: Whether the expert weights are fused into a single weight
+                tensor or are separate weight tensors for each expert.
+                When fused is True, loaded_weight should have shape of:
+                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
+                [hidden_out, hidden_in] for the others like router.
+                When fused is False, loaded_weight should have shape of:
+                [hidden_out, hidden_in].
+
+        Returns:
+            True if loaded_weight is one of MoE weights and the MoE expert
+            weights are loaded successfully, False otherwise.
+        """
+
+        # Whether the MoE expert weights are loaded successfully.
         expert_param_loaded = False
+
+        # If fused is True, the loaded weight is in the layout of:
+        # [num_experts, hidden_in, hidden_out], so we must transpose the last
+        # two dimensions to match the expected layout of the parameters.
+        if fused and loaded_weight.ndim == 3:
+            loaded_weight = loaded_weight.transpose(-1, -2)
+
+        # If the gate_proj and up_proj weights are fused into a single
+        # weight tensor, we need to split the weight tensor into a tuple
+        # of two weight tensors along the hidden_out dimension.
         if "experts.gate_up_proj" in name:
-            loaded_weight = loaded_weight.chunk(2, dim=-1)
+            loaded_weight = loaded_weight.chunk(2, dim=-2)
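One consistent reading of the two steps above, as a toy sketch with made-up sizes: the fused checkpoint tensor is transposed to [num_experts, hidden_out, hidden_in] first, after which chunking along dim=-2 separates the gate and up halves:

    import torch

    num_experts, hidden_in, intermediate = 4, 128, 256

    # Checkpoint layout: gate and up concatenated along hidden_out.
    gate_up = torch.rand(num_experts, hidden_in, 2 * intermediate)

    gate_up = gate_up.transpose(-1, -2)  # -> [num_experts, 2*intermediate, hidden_in]
    w1, w3 = gate_up.chunk(2, dim=-2)    # each [num_experts, intermediate, hidden_in]
    print(w1.shape, w3.shape)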

+        # Iterate over all the expert parameters and load the weights if we find
+        # a match in weight name.
         for (param_name, weight_name, expert_id,
              shard_id) in expert_params_mapping:

+            # Get a view of the loaded_weight to avoid modifying the original
+            # one across iterations.
             new_loaded_weight = loaded_weight

+            # If expert weights are fused into a single weight tensor, remove
+            # the expert index from the expected weight name.
             if fused:
+                # The string between e_str and proj_str is the expert index.
                 e_str, _, proj_str, _ = weight_name.split('.')
                 weight_name = f"{e_str}.{proj_str}"
                 param_name = f"{param_name}weight"
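To make the fused-name rewrite concrete, a small pure-Python sketch; the mapping entry shown is illustrative only (the exact strings come from FusedMoE.make_expert_params_mapping()):

    # A fused-style entry (num_experts=1) is assumed to look roughly like
    # ("experts.w13_", "experts.0.gate_up_proj.", 0, "w1").
    param_name, weight_name = "experts.w13_", "experts.0.gate_up_proj."

    e_str, _, proj_str, _ = weight_name.split('.')  # "experts", "0", "gate_up_proj", ""
    weight_name = f"{e_str}.{proj_str}"             # "experts.gate_up_proj"
    param_name = f"{param_name}weight"              # "experts.w13_weight"
    print(weight_name, param_name)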
+            # Skip if the current weight is not one of the MoE weights.
             if weight_name not in name:
                 continue

+            # Replace the weight name with the parameter name.
             full_param_name = name.replace(weight_name, param_name)
-            # Skip layers on other devices.

+            # Skip if the current weight corresponds to a parameter that
+            # does not exist on the current PP (pipeline parallel) rank.
             if is_pp_missing_parameter(name, self):
                 continue

+            # Skip if the current weight is for the bias.
             if ((name.endswith(".bias") or name.endswith("_bias"))
                     and name not in params_dict):
                 continue

             param = params_dict[full_param_name]
             weight_loader = param.weight_loader

             if fused:
+                # If the parameter is for w13 together, the corresponding weight
+                # will be a tuple, so we must select the correct weight
+                # depending on the shard id, which is either "w1" or "w3".
                 if "w13" in full_param_name:
+                    assert shard_id in ["w1", "w3"]
                     shard_idx = 0 if shard_id == "w1" else 1
                     new_loaded_weight = new_loaded_weight[shard_idx]
-                new_loaded_weight = new_loaded_weight.transpose(-1, -2)

+                # If EP (expert parallel) is enabled, update expert_id to the
+                # starting expert index for the current EP rank and extract the
+                # corresponding expert weights.
                 layer_idx = extract_layer_index(name)
-                # EP mapping
                 expert_map = self.layers[
                     layer_idx].feed_forward.experts.expert_map
                 if expert_map is not None:
@@ -382,6 +442,9 @@ class Llama4Model(LlamaModel):
             else:
                 # TODO: add EP support for non fused weights
                 pass
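The EP remapping relies on an expert_map held by each FusedMoE layer. As a rough sketch of the idea (the exact layout is an assumption here, not taken from the diff): a global-to-local table maps each global expert index either to a local slot on this EP rank or to -1 when the expert lives on another rank:

    import torch

    # Assumed layout: global_num_experts entries, -1 for experts not owned by
    # this EP rank, otherwise the local expert index on this rank.
    global_num_experts, ep_rank, experts_per_rank = 16, 1, 8
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int64)
    start = ep_rank * experts_per_rank
    expert_map[start:start + experts_per_rank] = torch.arange(experts_per_rank)

    global_expert_id = 10
    local_expert_id = int(expert_map[global_expert_id])  # 2 on rank 1
    if local_expert_id == -1:
        pass  # this rank does not own the expert; skip its weights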
+            # Load the weight into the module parameter with corresponding
+            # shard id and expert id.
             weight_loader(param,
                           new_loaded_weight,
                           full_param_name,
@@ -390,10 +453,13 @@ class Llama4Model(LlamaModel):

             loaded_params.add(full_param_name)
             expert_param_loaded = True

         return expert_param_loaded

     def load_weights(self, weights: Iterable[tuple[str,
                                              torch.Tensor]]) -> set[str]:
+        # Name mapping from the parameter name to the shard name and
+        # corresponding shard id.
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -402,26 +468,43 @@ class Llama4Model(LlamaModel):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+        # Indicate whether the expert weights are fused into a single weight
+        # tensor.
         fused_experts_params = False
+        # Expert parameter mapping for the case where the expert weights are
+        # not fused into a single weight tensor.
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.num_experts)
+        # Expert parameter mapping for the case where the expert weights are
+        # fused into a single weight tensor.
         expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="gate_up_proj",
             num_experts=1)
+        # All the module parameters.
         params_dict = dict(self.named_parameters())
+        # The module parameters that have been loaded.
         loaded_params: set[str] = set()

+        # Iterate over all the weights and load them into module parameters.
         for name, loaded_weight in weights:

+            # If the name contains "experts.gate_up_proj" or "experts.down_proj"
+            # without the expert indices, it means the expert weights are fused
+            # into a single weight tensor across all experts.
             if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                 fused_experts_params = True
                 expert_params_mapping = expert_params_mapping_fused
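A tiny self-contained sketch (hypothetical weight names) of the detection rule above: fused checkpoints name the stacked expert tensors without a per-expert index, while unfused checkpoints include one:

    def is_fused_expert_weight(name: str) -> bool:
        # Fused: "...experts.gate_up_proj" / "...experts.down_proj" (no index).
        # Unfused: "...experts.0.gate_proj.weight" and similar.
        return "experts.gate_up_proj" in name or "experts.down_proj" in name

    assert is_fused_expert_weight(
        "language_model.model.layers.0.feed_forward.experts.gate_up_proj")
    assert not is_fused_expert_weight(
        "language_model.model.layers.0.feed_forward.experts.0.gate_proj.weight")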

+            # If kv cache quantization scales exist and the weight name
+            # corresponds to one of the kv cache quantization scales, load
+            # them.
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -430,84 +513,119 @@ class Llama4Model(LlamaModel):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue

+            # Iterate over stacked_params_mapping to check if the current weight
+            # is one of the stacked parameters. If so, load the weight with the
+            # corresponding shard id. Note that MoE weights are handled
+            # separately in the else block.
             for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip if the current weight is not one of the stacked
+                # parameters or if the current weight is a MoE weight.
                 if weight_name not in name or "experts" in name:
                     continue
-                # This check is for ModelOpt ckpts with kv cache quant enabled

+                # For ModelOpt checkpoints, we need to rename the self_attn
+                # weight/weight_scale names except for kv cache scales.
                 if not (name.endswith(
                         (".k_scale", ".v_scale")) and "self_attn" in name):
                     name = name.replace(weight_name, param_name)

+                # Skip if the current weight corresponds to a parameter that
+                # does not exist on the current PP (pipeline parallel) rank.
                 if is_pp_missing_parameter(name, self):
                     continue
-                if name.endswith("scale") and "expert" not in name:
-                    # Remapping the name of FP8 kv-scale.

+                # Remap kv cache scale names for ModelOpt checkpoints.
+                # TODO: ModelOpt should implement get_cache_scale() such that
+                # kv cache scale name remapping can be done there.
+                if name.endswith("scale"):
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
                         continue

+                # Load the weight into the module parameter with corresponding
+                # shard id and exit the for loop and the else block.
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

                 if weight_loader == default_weight_loader:
                     weight_loader(param, loaded_weight)
                 else:
                     weight_loader(param, loaded_weight, shard_id)

                 loaded_params.add(name)
                 break

+            # Handle normal (non-stacked) weights and MoE weights.
             else:
-                moe_loaded = self.load_moe_expert_weights(
-                    name,
+                # First, try to load MoE weights using load_moe_expert_weights.
+                # If successful, move on to next loaded weight.
+                if self.load_moe_expert_weights(name,
                                                 loaded_weight,
                                                 params_dict,
                                                 loaded_params,
                                                 expert_params_mapping,
-                    fused=fused_experts_params)
+                                                fused=fused_experts_params):
+                    continue

-                if not moe_loaded:
+                # Skip if the current weight corresponds to a parameter that
+                # does not exist on the current PP (pipeline parallel) rank.
                 if is_pp_missing_parameter(name, self):
                     continue

-                    # Handle flat expert scale parameters that
-                    # don't match per-expert patterns
-                    if ("experts." in name and ("w13_input_scale" in name
-                                                or "w13_weight_scale" in name
-                                                or "w2_input_scale" in name
-                                                or "w2_weight_scale" in name)):
-                        # These are flat expert scales that apply to all experts
+                # Handle flat expert scale parameters that don't match
+                # per-expert patterns, i.e. one weight scale tensor for all
+                # experts.
+                scale_names = [
+                    "w13_input_scale", "w13_weight_scale", "w2_input_scale",
+                    "w2_weight_scale"
+                ]
+                if ("experts." in name and any(scale_name in name
+                                               for scale_name in scale_names)):

                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)

-                        # Check for MoE-specific loading support via
-                        # attribute instead of expensive runtime reflection
-                        supports_moe = getattr(weight_loader,
-                                               'supports_moe_loading', False)
+                    # If weight loader supports special moe loading, use it to
+                    # avoid expensive runtime reflection
+                    if getattr(weight_loader, 'supports_moe_loading', False):
+                        # Map the weight name to the corresponding shard id.
+                        shard_id = "w2" if "w2_" in name else "w1"

-                        if supports_moe:
-                            # This is a MoE weight loader
-                            if "w13_" in name:
-                                shard_id = "w1"
-                            elif "w2_" in name:
-                                shard_id = "w2"
-                            else:
-                                shard_id = "w1"
+                        # Transpose if weight scales are FP8 block scales with
+                        # three dimensions:
+                        # [num_experts, hidden_in, hidden_out].
+                        if name.endswith("weight_scale") \
+                                and loaded_weight.dtype == torch.float8_e4m3fn \
+                                and loaded_weight.ndim == 3:
+                            loaded_weight = loaded_weight.transpose(-1, -2)

+                        # Load the weight into the module parameter with
+                        # corresponding shard id and expert id.
                         weight_loader(param,
                                       loaded_weight,
                                       name,
                                       shard_id=shard_id,
                                       expert_id=0)

                     else:
                         # Regular weight loader (handles both
                         # param.weight_loader and default_weight_loader)
                         weight_loader(param, loaded_weight)

                     loaded_params.add(name)
                     continue

+                # Handle normal (non-stacked, non-MoE) weights.
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)

+        # Finally, return the set of loaded parameters.
         return loaded_params
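The supports_moe_loading check above is a duck-typing convention: the loader callable advertises MoE-aware loading through an attribute, so the caller can branch without inspecting signatures. A minimal sketch of the idea (hypothetical loader functions, not vLLM's actual ones):

    import torch

    def moe_aware_loader(param, loaded_weight, name, shard_id, expert_id):
        # Placeholder for a loader that understands shard/expert routing.
        param.data.copy_(loaded_weight)

    moe_aware_loader.supports_moe_loading = True  # advertised capability

    def default_loader(param, loaded_weight):
        param.data.copy_(loaded_weight)

    param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
    loaded_weight = torch.ones(4)
    name = "experts.w2_weight_scale"

    loader = moe_aware_loader
    if getattr(loader, "supports_moe_loading", False):
        shard_id = "w2" if "w2_" in name else "w1"
        loader(param, loaded_weight, name, shard_id=shard_id, expert_id=0)
    else:
        loader(param, loaded_weight)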
@@ -560,23 +678,43 @@ class Llama4ForCausalLM(LlamaForCausalLM):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:

-        def permute(w: torch.Tensor, n_heads: int):
+        # Helper function to permute the weight's channels
+        def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
+
+            # Calculate the expected shape of the weight.
+            # Do not rely on w's shape, as it may be in another layout.
             attn_in = self.config.head_dim * n_heads
             attn_out = self.config.hidden_size

+            # If the weight is FP4 packed as uint8, we need to divide attn_out
+            # by 2.
+            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
+                attn_out = attn_out // 2
+
+            # If the weight is a weight scale, we need to divide attn_out by
+            # block size, which is currently 16.
+            elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
+                    and w.shape[1] * 16 == attn_out:
+                attn_out = attn_out // 16
+
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
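To see what the view/transpose/reshape in permute does, here is a toy demonstration (tiny made-up sizes): rows that were stored as interleaved pairs are regrouped into two contiguous halves per head, which is the rotary-embedding channel ordering the model expects:

    import torch

    # Toy sizes: one head with head_dim=4 and a single output column, so each
    # row index is easy to track through the permutation.
    n_heads, head_dim, attn_out = 1, 4, 1
    attn_in = n_heads * head_dim

    w = torch.arange(attn_in * attn_out, dtype=torch.float32).reshape(attn_in, attn_out)

    permuted = w.view(n_heads, attn_in // n_heads // 2, 2,
                      attn_out).transpose(1, 2).reshape(attn_in, attn_out)

    print(w.flatten().tolist())         # [0.0, 1.0, 2.0, 3.0]
    print(permuted.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0] -- pairs regrouped into halves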

         modules = name.split(".")

-        # rotary embeds should be sliced
-        if ("wk" in modules or "k_proj" in modules) \
-                and modules[-1] == "weight":
+        # Permute Q/K weights and weight block scales for rotary embedding
+        is_weight = modules[-1] == "weight"
+        is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
+                                 loaded_weight.dtype == torch.float8_e4m3fn)

+        if is_weight or is_nvfp4_weight_scale:
+            if ("wk" in modules or "k_proj" in modules):
                 loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
-        elif ("wq" in modules or "q_proj" in modules) \
-                and modules[-1] == "weight":
+                                        self.config.num_key_value_heads,
+                                        is_nvfp4_weight_scale)
+            elif ("wq" in modules or "q_proj" in modules):
                 loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                        self.config.num_attention_heads,
+                                        is_nvfp4_weight_scale)

         return name, loaded_weight
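A short numeric sketch (hypothetical sizes, for illustration only) of why the permute helper rescales attn_out: an FP4 weight packed two values per uint8 stores half as many columns, and an FP8-E4M3 block scale stores one column per 16-element block:

    hidden_size = 5120          # attn_out for an unquantized Q/K weight
    head_dim, n_kv_heads = 128, 8
    attn_in = head_dim * n_kv_heads

    fp4_packed_cols = hidden_size // 2    # uint8 storage: 2 FP4 values per byte
    block_scale_cols = hidden_size // 16  # one FP8-E4M3 scale per 16 channels

    print(attn_in, hidden_size, fp4_packed_cols, block_scale_cols)
    # 1024 5120 2560 320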