[feat] move WEIGHT_SCALE_SUPPORTED into raise block to accelerate RLHF weight loading (#21164)

Signed-off-by: huangweixiao <huangweixiao@msh.team>
This commit is contained in:
Weixiao Huang 2025-08-04 15:43:06 +08:00 committed by GitHub
parent a7b8788d2c
commit c1b4eb048a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1079,9 +1079,6 @@ class FusedMoE(torch.nn.Module):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size_per_partition is used.
@@ -1230,6 +1227,9 @@ class FusedMoE(torch.nn.Module):
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return True if return_success else None