[FIXBUG] Add return_success parameter to moe_wna16_weight_loader function (#22797)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: JartX
Date: 2025-08-28 11:03:22 +02:00 (committed by GitHub)
parent c5d004aaaf
commit 3462c1c522


@@ -124,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig):
         awq_min_capability = AWQConfig.get_min_capability()
         gptq_compatible = quant_method == "gptq" and \
             not desc_act and num_bits in [4, 8]
         awq_compatible = quant_method == "awq" and num_bits == 4 and \
             device_capability >= awq_min_capability
@@ -175,11 +175,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
     """
-    def __init__(
-        self,
-        quant_config: MoeWNA16Config,
-        moe: FusedMoEConfig,
-    ):
+    def __init__(self, quant_config: MoeWNA16Config,
+                 moe: "FusedMoEConfig") -> None:
         super().__init__(moe)
         self.quant_config = quant_config
@@ -187,6 +184,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
+        self.moe = layer
         layer.quant_config = self.quant_config
         bit8_pack_factor = self.quant_config.bit8_pack_factor
         group_size = self.quant_config.group_size
@@ -308,7 +306,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        assert self.fused_experts is None
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `MoeWNA16Method` yet.")
@@ -404,12 +401,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                     loaded_weight: torch.Tensor,
-                                    weight_name: str, shard_id: str,
-                                    expert_id: int):
+                                    weight_name: str,
+                                    shard_id: str,
+                                    expert_id: int,
+                                    return_success: bool = False):
             if "g_idx" in weight_name:
-                return
+                return False if return_success else None
             if not layer.quant_config.has_zp and "qzeros" in weight_name:
-                return
+                return False if return_success else None
             device = get_tp_group().device
             tp_rank = get_tensor_model_parallel_rank()
@@ -455,11 +454,18 @@ class MoeWNA16Method(FusedMoEMethodBase):
                     param.data[expert_id, :shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2:] = tensor
+                return True if return_success else None
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
                     loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
+                return True if return_success else None
             else:
-                weight_loader(param, loaded_weight, weight_name, shard_id,
-                              expert_id)
+                # Delegate to the original loader, passing return_success
+                return weight_loader(param,
+                                     loaded_weight,
+                                     weight_name,
+                                     shard_id,
+                                     expert_id,
+                                     return_success=return_success)
 
         return moe_wna16_weight_loader
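
Not part of the commit: a minimal sketch of how a caller might use the new flag, assuming only the semantics visible in the diff above (the wrapper returns True when it consumed a tensor, False when it skipped it, and keeps returning None when return_success is left at its default). Returning "False if return_success else None" preserves legacy call sites that ignore the return value while presumably matching callers that now pass return_success through to the underlying weight loader. The function and variable names below are hypothetical.

from typing import Iterable

import torch


def load_expert_weights(loader,  # e.g. the wrapped moe_wna16_weight_loader
                        param: torch.nn.Parameter,
                        weights: Iterable[tuple[str, str, int, torch.Tensor]]) -> list[str]:
    """Return the names of the checkpoint tensors the loader actually consumed."""
    consumed = []
    for weight_name, shard_id, expert_id, tensor in weights:
        # With return_success=True the loader reports whether it handled the
        # tensor instead of silently returning None, so skipped weights
        # (e.g. "g_idx" or zero points when has_zp is False) can be tracked.
        if loader(param, tensor, weight_name, shard_id, expert_id,
                  return_success=True):
            consumed.append(weight_name)
    return consumed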