[ROCm][fused_moe][fp4] View weights as torch.float4_e2m1fn_x2 when running the aiter fused MoE for FP4 models (#27474)

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Author: zejunchen-zejun, 2025-11-10 23:38:56 +08:00 (committed by GitHub)
parent 4673e465ff
commit b06b9470ca


@@ -458,6 +458,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
         self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
+        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
         self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
             self.input_dtype, self.weight_dtype
@@ -581,6 +582,17 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
         w2_weight_scale = e8m0_shuffle(w2_weight_scale)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+        if self.fp4_dtype is not None:
+            layer.w13_weight = torch.nn.Parameter(
+                layer.w13_weight.view(self.fp4_dtype),
+                requires_grad=layer.w13_weight.requires_grad,
+            )
+            layer.w2_weight = torch.nn.Parameter(
+                layer.w2_weight.view(self.fp4_dtype),
+                requires_grad=layer.w2_weight.requires_grad,
+            )
+
         torch.cuda.empty_cache()

     def get_fused_moe_quant_config(
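
For context on what the change does: FP4 MoE weights are loaded as packed uint8 storage (two e2m1 values per byte), and this patch reinterprets that storage as the FP4 dtype before handing it to the aiter kernels. Below is a minimal standalone sketch, not part of the diff, illustrating the same reinterpretation; the uint8 tensor and its shape are made up for illustration, and the getattr guard mirrors the patch's handling of PyTorch builds that do not define the dtype.

    import torch

    # Guarded lookup: older PyTorch builds do not expose float4_e2m1fn_x2.
    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    # Stand-in for a loaded FP4 weight: two e2m1 values packed per byte.
    packed_weight = torch.randint(0, 256, (128, 64), dtype=torch.uint8)

    if fp4_dtype is not None:
        # float4_e2m1fn_x2 has the same 1-byte itemsize as uint8, so the
        # view is zero-copy and shape-preserving; downstream code now sees
        # an FP4-typed tensor instead of raw uint8 storage.
        fp4_weight = torch.nn.Parameter(
            packed_weight.view(fp4_dtype), requires_grad=False
        )
        assert fp4_weight.shape == packed_weight.shape

Because the view copies no data, wrapping the result back into a torch.nn.Parameter (as the patch does for w13_weight and w2_weight) changes only the tensor's dtype metadata, not its memory footprint.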