Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 20:54:59 +08:00)
[FIXBUG] Add return_success parameter to moe_wna16_weight_loader function (#22797)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent c5d004aaaf
commit 3462c1c522
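
The patch teaches the MoE WNA16 per-expert weight loader to report whether a tensor was actually applied, instead of always returning None. Below is a minimal, self-contained sketch of the opt-in success-flag pattern the diff introduces; it is a toy loader for illustration only (every name except return_success is invented, and it is not vLLM code):

    from typing import Optional

    def toy_weight_loader(weight_name: str,
                          return_success: bool = False) -> Optional[bool]:
        # Tensors the loader intentionally skips report False, but only when
        # the caller opted in; legacy callers keep getting None.
        if "g_idx" in weight_name:
            return False if return_success else None
        # ... a real loader would copy the checkpoint tensor into the param here ...
        return True if return_success else None

    assert toy_weight_loader("w13_qweight") is None                  # legacy call
    assert toy_weight_loader("w13_qweight", return_success=True) is True
    assert toy_weight_loader("w13_g_idx", return_success=True) is False
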
@@ -124,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig):
         awq_min_capability = AWQConfig.get_min_capability()
 
         gptq_compatible = quant_method == "gptq" and \
             not desc_act and num_bits in [4, 8]
         awq_compatible = quant_method == "awq" and num_bits == 4 and \
             device_capability >= awq_min_capability
 
@@ -175,11 +175,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
     """
 
-    def __init__(
-        self,
-        quant_config: MoeWNA16Config,
-        moe: FusedMoEConfig,
-    ):
+    def __init__(self, quant_config: MoeWNA16Config,
+                 moe: "FusedMoEConfig") -> None:
         super().__init__(moe)
         self.quant_config = quant_config
 
@@ -187,6 +184,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        self.moe = layer
         layer.quant_config = self.quant_config
         bit8_pack_factor = self.quant_config.bit8_pack_factor
         group_size = self.quant_config.group_size
@@ -308,7 +306,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert self.fused_experts is None
-
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `MoeWNA16Method` yet.")
@@ -404,12 +401,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
 
         def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                     loaded_weight: torch.Tensor,
-                                    weight_name: str, shard_id: str,
-                                    expert_id: int):
+                                    weight_name: str,
+                                    shard_id: str,
+                                    expert_id: int,
+                                    return_success: bool = False):
             if "g_idx" in weight_name:
-                return
+                return False if return_success else None
             if not layer.quant_config.has_zp and "qzeros" in weight_name:
-                return
+                return False if return_success else None
 
             device = get_tp_group().device
             tp_rank = get_tensor_model_parallel_rank()
@@ -455,11 +454,18 @@ class MoeWNA16Method(FusedMoEMethodBase):
                     param.data[expert_id, :shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2:] = tensor
+                return True if return_success else None
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
                     loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
+                return True if return_success else None
             else:
-                weight_loader(param, loaded_weight, weight_name, shard_id,
-                              expert_id)
+                # Delegate to the original loader, passing return_success
+                return weight_loader(param,
+                                     loaded_weight,
+                                     weight_name,
+                                     shard_id,
+                                     expert_id,
+                                     return_success=return_success)
 
         return moe_wna16_weight_loader
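
The diff does not show the call site that consumes the new flag. A hypothetical consumer (the function name load_expert_weight and its signature are invented here, not part of this commit) could use return_success=True to distinguish a tensor that was skipped on purpose from one that was loaded:

    from typing import Callable, Optional

    import torch

    def load_expert_weight(loader: Callable[..., Optional[bool]],
                           param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor,
                           weight_name: str, shard_id: str,
                           expert_id: int) -> bool:
        # Ask the loader for an explicit status instead of assuming success.
        status = loader(param, loaded_weight, weight_name, shard_id, expert_id,
                        return_success=True)
        # Loaders that predate return_success still return None; treat that as
        # success to preserve the old behaviour.
        return status is not False
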