From 3462c1c522d214755f1dfce3d645ab5afe7f00ae Mon Sep 17 00:00:00 2001
From: JartX
Date: Thu, 28 Aug 2025 11:03:22 +0200
Subject: [PATCH] [FIXBUG] Add return_success parameter to
 moe_wna16_weight_loader function (#22797)

Signed-off-by: JartX
Co-authored-by: Michael Goin
---
 .../layers/quantization/moe_wna16.py          | 32 +++++++++++--------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 364d1ac314d2d..0cde104cc75d7 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -124,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig):
         awq_min_capability = AWQConfig.get_min_capability()
 
         gptq_compatible = quant_method == "gptq" and \
-                not desc_act and num_bits in [4, 8]
+            not desc_act and num_bits in [4, 8]
         awq_compatible = quant_method == "awq" and num_bits == 4 and \
             device_capability >= awq_min_capability
 
@@ -175,11 +175,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
     """
 
-    def __init__(
-        self,
-        quant_config: MoeWNA16Config,
-        moe: FusedMoEConfig,
-    ):
+    def __init__(self, quant_config: MoeWNA16Config,
+                 moe: "FusedMoEConfig") -> None:
         super().__init__(moe)
         self.quant_config = quant_config
 
@@ -187,6 +184,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
                        hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
+        self.moe = layer
         layer.quant_config = self.quant_config
         bit8_pack_factor = self.quant_config.bit8_pack_factor
         group_size = self.quant_config.group_size
@@ -308,7 +306,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert self.fused_experts is None
-
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `MoeWNA16Method` yet.")
@@ -404,12 +401,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
 
         def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                     loaded_weight: torch.Tensor,
-                                    weight_name: str, shard_id: str,
-                                    expert_id: int):
+                                    weight_name: str,
+                                    shard_id: str,
+                                    expert_id: int,
+                                    return_success: bool = False):
             if "g_idx" in weight_name:
-                return
+                return False if return_success else None
             if not layer.quant_config.has_zp and "qzeros" in weight_name:
-                return
+                return False if return_success else None
 
             device = get_tp_group().device
             tp_rank = get_tensor_model_parallel_rank()
@@ -455,11 +454,18 @@ class MoeWNA16Method(FusedMoEMethodBase):
                     param.data[expert_id, :shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2:] = tensor
+                return True if return_success else None
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
                     loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
+                return True if return_success else None
             else:
-                weight_loader(param, loaded_weight, weight_name, shard_id,
-                              expert_id)
+                # Delegate to the original loader, passing return_success
+                return weight_loader(param,
+                                     loaded_weight,
+                                     weight_name,
+                                     shard_id,
+                                     expert_id,
+                                     return_success=return_success)
 
         return moe_wna16_weight_loader
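
Note on the pattern this patch introduces: the opt-in return_success flag lets
callers distinguish "tensor intentionally skipped" (g_idx buffers, or qzeros
when the config has no zero-points) from "tensor written", while legacy callers
that omit the flag keep receiving None. Below is a minimal, self-contained
sketch of that backward-compatible convention; toy_weight_loader is
hypothetical and only stands in for the patched loader, it is not part of vLLM.

    import torch

    # Toy stand-in for the patched loader (hypothetical, for illustration):
    # skipped tensors report False only when the caller opts in via
    # return_success; callers that omit the flag still see None.
    def toy_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor,
                          weight_name: str,
                          return_success: bool = False):
        if "g_idx" in weight_name:
            # Deliberately skipped, mirroring the early returns in the patch.
            return False if return_success else None
        param.data.copy_(loaded_weight)
        return True if return_success else None

    param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
    # Legacy call site: no flag, no visible behavior change.
    assert toy_weight_loader(param, torch.ones(4), "w13_qweight") is None
    # New call sites can check whether the weight was actually consumed.
    assert toy_weight_loader(param, torch.ones(4), "w13_qweight",
                             return_success=True) is True
    assert toy_weight_loader(param, torch.ones(4), "w13_g_idx",
                             return_success=True) is False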