mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:35:50 +08:00
[NVIDIA] Fix Llama4 Scout FP4 functionality issues (#21499)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent
8f4a1c9a04
commit
ff08e51940
@ -874,6 +874,14 @@ class FusedMoE(torch.nn.Module):
|
|||||||
elif shard_id == "w2":
|
elif shard_id == "w2":
|
||||||
param_data[expert_id] = loaded_weight
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
def _load_w13_weight_scale(self, shard_dim: int,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
param: torch.Tensor, tp_rank: int):
|
||||||
|
shard_size = param.shape[shard_dim]
|
||||||
|
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||||
|
shard_size)
|
||||||
|
param.copy_(loaded_weight)
|
||||||
|
|
||||||
def _load_model_weight_or_group_weight_scale(self,
|
def _load_model_weight_or_group_weight_scale(self,
|
||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
@ -1123,7 +1131,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
"weight_scale_2" in weight_name if uses_weight_scale_2 else
|
"weight_scale_2" in weight_name if uses_weight_scale_2 else
|
||||||
"weight_scale" in weight_name) or "input_scale" in weight_name
|
"weight_scale" in weight_name) or "input_scale" in weight_name
|
||||||
|
|
||||||
if per_tensor_conditions:
|
if "w13_weight_scale" in weight_name:
|
||||||
|
self._load_w13_weight_scale(shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
param=param,
|
||||||
|
tp_rank=self.tp_rank)
|
||||||
|
elif per_tensor_conditions:
|
||||||
self._load_per_tensor_weight_scale(
|
self._load_per_tensor_weight_scale(
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
param=param,
|
param=param,
|
||||||
|
|||||||
@ -778,8 +778,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
# Swizzle the weight blockscale.
|
# Swizzle the weight blockscale.
|
||||||
# contracting dimension is input dimension
|
# contracting dimension is input dimension
|
||||||
# block_size = 16;
|
# block_size = 16;
|
||||||
assert (layer.weight_scale.shape[1] % 16 == 0), (
|
|
||||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
|
||||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||||
"Weight Block scale must be represented as FP8-E4M3")
|
"Weight Block scale must be represented as FP8-E4M3")
|
||||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||||
|
|||||||
@ -342,34 +342,94 @@ class Llama4Model(LlamaModel):
|
|||||||
expert_params_mapping: list[tuple[str, str, int, str]],
|
expert_params_mapping: list[tuple[str, str, int, str]],
|
||||||
fused: bool = True,
|
fused: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Load MoE expert weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the weight to load.
|
||||||
|
loaded_weight: The weight to load.
|
||||||
|
params_dict: The dictionary of module parameters.
|
||||||
|
loaded_params: The set of already loaded parameters.
|
||||||
|
expert_params_mapping: The mapping of expert parameters. Must be
|
||||||
|
generated by FusedMoE.make_expert_params_mapping().
|
||||||
|
fused: Whether the expert weights are fused into a single weight
|
||||||
|
tensor or are separate weight tensors for each expert.
|
||||||
|
When fused is True, loaded_weight should have shape of:
|
||||||
|
[num_experts, hidden_in, hidden_out] for gate/up/down proj and
|
||||||
|
[hidden_out, hidden_in] for the others like router.
|
||||||
|
When fused is False, loaded_weight should have shape of:
|
||||||
|
[hidden_out, hidden_in].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if loaded_weight is one of MoE weights and the MoE expert
|
||||||
|
weights are loaded successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Whether the MoE expert weights are loaded successfully.
|
||||||
expert_param_loaded = False
|
expert_param_loaded = False
|
||||||
if "experts.gate_up_proj" in name:
|
|
||||||
loaded_weight = loaded_weight.chunk(2, dim=-1)
|
# If fused is True, the loaded weight is in the layout of:
|
||||||
|
# [num_experts, hidden_in, hidden_out], so we must transpose the last
|
||||||
|
# two dimensions to match the expected layout of the parameters.
|
||||||
|
if fused and loaded_weight.ndim == 3:
|
||||||
|
loaded_weight = loaded_weight.transpose(-1, -2)
|
||||||
|
|
||||||
|
# If the gate_proj and up_proj weights are fused into a single
|
||||||
|
# weight tensor, we need to split the weight tensor into a tuple
|
||||||
|
# of two weight tensors along the hidden_out dimension.
|
||||||
|
if "experts.gate_up_proj" in name:
|
||||||
|
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||||
|
|
||||||
|
# Iterate over all the expert parameters and load the weights if we find
|
||||||
|
# a match in weight name.
|
||||||
for (param_name, weight_name, expert_id,
|
for (param_name, weight_name, expert_id,
|
||||||
shard_id) in expert_params_mapping:
|
shard_id) in expert_params_mapping:
|
||||||
|
|
||||||
|
# Get a view of the loaded_weight to avoid modifying the original
|
||||||
|
# one across iterations.
|
||||||
new_loaded_weight = loaded_weight
|
new_loaded_weight = loaded_weight
|
||||||
|
|
||||||
|
# If expert weights are fused into a single weight tensor, remove
|
||||||
|
# the expert index from the expected weight name.
|
||||||
if fused:
|
if fused:
|
||||||
|
# The string between e_str and proj_str is the expert index.
|
||||||
e_str, _, proj_str, _ = weight_name.split('.')
|
e_str, _, proj_str, _ = weight_name.split('.')
|
||||||
weight_name = f"{e_str}.{proj_str}"
|
weight_name = f"{e_str}.{proj_str}"
|
||||||
param_name = f"{param_name}weight"
|
param_name = f"{param_name}weight"
|
||||||
|
|
||||||
|
# Skip if the current weight is not one of the MoE weights.
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Replace the weight name with the parameter name.
|
||||||
full_param_name = name.replace(weight_name, param_name)
|
full_param_name = name.replace(weight_name, param_name)
|
||||||
# Skip layers on other devices.
|
|
||||||
|
# Skip if the current weight corresponds to a parameter that
|
||||||
|
# does not exist on the current PP (pipeline parallel) rank.
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip if the current weight is for the bias.
|
||||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||||
and name not in params_dict):
|
and name not in params_dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[full_param_name]
|
param = params_dict[full_param_name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
|
|
||||||
if fused:
|
if fused:
|
||||||
|
# If the parameter is for w13 together, the corresponding weight
|
||||||
|
# will be a tuple, so we must select the correct weight
|
||||||
|
# depending on the shard id, which is either "w1" or "w3".
|
||||||
if "w13" in full_param_name:
|
if "w13" in full_param_name:
|
||||||
|
assert shard_id in ["w1", "w3"]
|
||||||
shard_idx = 0 if shard_id == "w1" else 1
|
shard_idx = 0 if shard_id == "w1" else 1
|
||||||
new_loaded_weight = new_loaded_weight[shard_idx]
|
new_loaded_weight = new_loaded_weight[shard_idx]
|
||||||
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
|
|
||||||
|
# If EP (expert parallel) is enabled, update expert_id to the
|
||||||
|
# starting expert index for the current EP rank and extract the
|
||||||
|
# corresponding expert weights.
|
||||||
layer_idx = extract_layer_index(name)
|
layer_idx = extract_layer_index(name)
|
||||||
# EP mapping
|
|
||||||
expert_map = self.layers[
|
expert_map = self.layers[
|
||||||
layer_idx].feed_forward.experts.expert_map
|
layer_idx].feed_forward.experts.expert_map
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
@ -382,6 +442,9 @@ class Llama4Model(LlamaModel):
|
|||||||
else:
|
else:
|
||||||
# TODO: add EP support for non fused weights
|
# TODO: add EP support for non fused weights
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Load the weight into the module parameter with corresponding
|
||||||
|
# shard id and expert id.
|
||||||
weight_loader(param,
|
weight_loader(param,
|
||||||
new_loaded_weight,
|
new_loaded_weight,
|
||||||
full_param_name,
|
full_param_name,
|
||||||
@ -390,10 +453,13 @@ class Llama4Model(LlamaModel):
|
|||||||
|
|
||||||
loaded_params.add(full_param_name)
|
loaded_params.add(full_param_name)
|
||||||
expert_param_loaded = True
|
expert_param_loaded = True
|
||||||
|
|
||||||
return expert_param_loaded
|
return expert_param_loaded
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
|
# Name mapping from the parameter name to the shard name and
|
||||||
|
# corresponding shard id.
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
(".qkv_proj", ".q_proj", "q"),
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
@ -402,26 +468,43 @@ class Llama4Model(LlamaModel):
|
|||||||
(".gate_up_proj", ".gate_proj", 0),
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
(".gate_up_proj", ".up_proj", 1),
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
]
|
]
|
||||||
|
# Indicate whether the expert weights are fused into a single weight
|
||||||
|
# tensor.
|
||||||
fused_experts_params = False
|
fused_experts_params = False
|
||||||
|
# Expert parameter mapping for the case where the expert weights are
|
||||||
|
# not fused into a single weight tensor.
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
num_experts=self.num_experts)
|
num_experts=self.num_experts)
|
||||||
|
# Expert parameter mapping for the case where the expert weights are
|
||||||
|
# fused into a single weight tensor.
|
||||||
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
|
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_up_proj",
|
ckpt_gate_proj_name="gate_up_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="gate_up_proj",
|
ckpt_up_proj_name="gate_up_proj",
|
||||||
num_experts=1)
|
num_experts=1)
|
||||||
|
# All the module parameters.
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
# The module parameters that have been loaded.
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
|
||||||
|
# Iterate over all the weights and load them into module parameters.
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
|
||||||
|
# If the name contains "experts.gate_up_proj" or "experts.down_proj"
|
||||||
|
# without the expert indices, it means the expert weights are fused
|
||||||
|
# into a single weight tensor across all experts.
|
||||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||||
fused_experts_params = True
|
fused_experts_params = True
|
||||||
expert_params_mapping = expert_params_mapping_fused
|
expert_params_mapping = expert_params_mapping_fused
|
||||||
|
|
||||||
|
# If kv cache quantization scales exist and the weight name
|
||||||
|
# corresponds to one of the kv cache quantization scales, load
|
||||||
|
# them.
|
||||||
if (self.quant_config is not None and
|
if (self.quant_config is not None and
|
||||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
# Loading kv cache quantization scales
|
|
||||||
param = params_dict[scale_name]
|
param = params_dict[scale_name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
@ -430,84 +513,119 @@ class Llama4Model(LlamaModel):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(scale_name)
|
loaded_params.add(scale_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Iterate over stacked_params_mapping to check if the current weight
|
||||||
|
# is one of the stacked parameters. If so, load the weight with the
|
||||||
|
# corresponding shard id. Note that MoE weights are handled
|
||||||
|
# separately in the else block.
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
# Skip if the current weight is not one of the stacked
|
||||||
|
# parameters or if the current weight is a MoE weight.
|
||||||
if weight_name not in name or "experts" in name:
|
if weight_name not in name or "experts" in name:
|
||||||
continue
|
continue
|
||||||
# This check is for ModelOpt ckpts with kv cache quant enabled
|
|
||||||
|
# For ModelOpt checkpoints, we need to rename the self_attn
|
||||||
|
# weight/weight_scale names except for kv cache scales.
|
||||||
if not (name.endswith(
|
if not (name.endswith(
|
||||||
(".k_scale", ".v_scale")) and "self_attn" in name):
|
(".k_scale", ".v_scale")) and "self_attn" in name):
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
# Skip if the current weight corresponds to a parameter that
|
||||||
|
# does not exist on the current PP (pipeline parallel) rank.
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
if name.endswith("scale") and "expert" not in name:
|
|
||||||
# Remapping the name of FP8 kv-scale.
|
# Remap kv cache scale names for ModelOpt checkpoints.
|
||||||
|
# TODO: ModelOpt should implement get_cache_scale() such that
|
||||||
|
# kv cache scale name remapping can be done there.
|
||||||
|
if name.endswith("scale"):
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
if name is None:
|
if name is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Load the weight into the module parameter with corresponding
|
||||||
|
# shard id and exit the for loop and the else block.
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|
||||||
if weight_loader == default_weight_loader:
|
if weight_loader == default_weight_loader:
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
else:
|
else:
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Handle normal (non-stacked) weights and MoE weights.
|
||||||
else:
|
else:
|
||||||
moe_loaded = self.load_moe_expert_weights(
|
# First, try to load MoE weights using load_moe_expert_weights.
|
||||||
name,
|
# If successful, move on to next loaded weight.
|
||||||
loaded_weight,
|
if self.load_moe_expert_weights(name,
|
||||||
params_dict,
|
loaded_weight,
|
||||||
loaded_params,
|
params_dict,
|
||||||
expert_params_mapping,
|
loaded_params,
|
||||||
fused=fused_experts_params)
|
expert_params_mapping,
|
||||||
|
fused=fused_experts_params):
|
||||||
|
continue
|
||||||
|
|
||||||
if not moe_loaded:
|
# Skip if the current weight corresponds to a parameter that
|
||||||
if is_pp_missing_parameter(name, self):
|
# does not exist on the current PP (pipeline parallel) rank.
|
||||||
continue
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
# Handle flat expert scale parameters that
|
# Handle flat expert scale parameters that don't match
|
||||||
# don't match per-expert patterns
|
# per-expert patterns, i.e. one weight scale tensor for all
|
||||||
if ("experts." in name and ("w13_input_scale" in name
|
# experts.
|
||||||
or "w13_weight_scale" in name
|
scale_names = [
|
||||||
or "w2_input_scale" in name
|
"w13_input_scale", "w13_weight_scale", "w2_input_scale",
|
||||||
or "w2_weight_scale" in name)):
|
"w2_weight_scale"
|
||||||
# These are flat expert scales that apply to all experts
|
]
|
||||||
param = params_dict[name]
|
if ("experts." in name and any(scale_name in name
|
||||||
weight_loader = getattr(param, "weight_loader",
|
for scale_name in scale_names)):
|
||||||
default_weight_loader)
|
|
||||||
|
|
||||||
# Check for MoE-specific loading support via
|
|
||||||
# attribute instead of expensive runtime reflection
|
|
||||||
supports_moe = getattr(weight_loader,
|
|
||||||
'supports_moe_loading', False)
|
|
||||||
|
|
||||||
if supports_moe:
|
|
||||||
# This is a MoE weight loader
|
|
||||||
if "w13_" in name:
|
|
||||||
shard_id = "w1"
|
|
||||||
elif "w2_" in name:
|
|
||||||
shard_id = "w2"
|
|
||||||
else:
|
|
||||||
shard_id = "w1"
|
|
||||||
|
|
||||||
weight_loader(param,
|
|
||||||
loaded_weight,
|
|
||||||
name,
|
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=0)
|
|
||||||
else:
|
|
||||||
# Regular weight loader (handles both
|
|
||||||
# param.weight_loader and default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
|
# If weight loader supports special moe loading, use it to
|
||||||
|
# avoid expensive runtime reflection
|
||||||
|
if getattr(weight_loader, 'supports_moe_loading', False):
|
||||||
|
# Map the weight name to the corresponding shard id.
|
||||||
|
shard_id = "w2" if "w2_" in name else "w1"
|
||||||
|
|
||||||
|
# Transpose if weight scales are FP8 block scales with
|
||||||
|
# three dimensions:
|
||||||
|
# [num_experts, hidden_in, hidden_out].
|
||||||
|
if name.endswith("weight_scale") \
|
||||||
|
and loaded_weight.dtype == torch.float8_e4m3fn \
|
||||||
|
and loaded_weight.ndim == 3:
|
||||||
|
loaded_weight = loaded_weight.transpose(-1, -2)
|
||||||
|
|
||||||
|
# Load the weight into the module parameter with
|
||||||
|
# corresponding shard id and expert id.
|
||||||
|
weight_loader(param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Regular weight loader (handles both
|
||||||
|
# param.weight_loader and default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle normal (non-stacked, non-MoE) weights.
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
# Finally, return the set of loaded parameters.
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
@ -560,23 +678,43 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
|||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
) -> tuple[str, torch.Tensor]:
|
) -> tuple[str, torch.Tensor]:
|
||||||
|
|
||||||
def permute(w: torch.Tensor, n_heads: int):
|
# Helper function to permute the weight's channels
|
||||||
|
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
|
||||||
|
|
||||||
|
# Calculate the expected shape of the weight.
|
||||||
|
# Do not rely on w's shape, as it may be in another layout.
|
||||||
attn_in = self.config.head_dim * n_heads
|
attn_in = self.config.head_dim * n_heads
|
||||||
attn_out = self.config.hidden_size
|
attn_out = self.config.hidden_size
|
||||||
|
|
||||||
|
# If the weight is FP4 packed as uint8, we need to divide attn_out
|
||||||
|
# by 2.
|
||||||
|
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
|
||||||
|
attn_out = attn_out // 2
|
||||||
|
|
||||||
|
# If the weight is a weight scale, we need to divide attn_out by
|
||||||
|
# block size, which is currently 16.
|
||||||
|
elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
|
||||||
|
and w.shape[1] * 16 == attn_out:
|
||||||
|
attn_out = attn_out // 16
|
||||||
|
|
||||||
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||||
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
||||||
|
|
||||||
modules = name.split(".")
|
modules = name.split(".")
|
||||||
|
|
||||||
# rotary embeds should be sliced
|
# Permute Q/K weights and weight block scales for rotary embedding
|
||||||
if ("wk" in modules or "k_proj" in modules) \
|
is_weight = modules[-1] == "weight"
|
||||||
and modules[-1] == "weight":
|
is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
|
||||||
loaded_weight = permute(loaded_weight,
|
loaded_weight.dtype == torch.float8_e4m3fn)
|
||||||
self.config.num_key_value_heads)
|
|
||||||
elif ("wq" in modules or "q_proj" in modules) \
|
if is_weight or is_nvfp4_weight_scale:
|
||||||
and modules[-1] == "weight":
|
if ("wk" in modules or "k_proj" in modules):
|
||||||
loaded_weight = permute(loaded_weight,
|
loaded_weight = permute(loaded_weight,
|
||||||
self.config.num_attention_heads)
|
self.config.num_key_value_heads,
|
||||||
|
is_nvfp4_weight_scale)
|
||||||
|
elif ("wq" in modules or "q_proj" in modules):
|
||||||
|
loaded_weight = permute(loaded_weight,
|
||||||
|
self.config.num_attention_heads,
|
||||||
|
is_nvfp4_weight_scale)
|
||||||
|
|
||||||
return name, loaded_weight
|
return name, loaded_weight
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user