From b06b9470ca881f89feea72e4d89b20c213f360d4 Mon Sep 17 00:00:00 2001
From: zejunchen-zejun
Date: Mon, 10 Nov 2025 23:38:56 +0800
Subject: [PATCH] [Rocm][fused_moe][fp4] view weight to torch.float4_e2m1fn_x2
 when running aiter fused moe for fp4 model (#27474)

Signed-off-by: zejunchen-zejun
---
 .../layers/quantization/quark/quark_moe.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 8825611051e5d..eca6b0cb1d8e5 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -458,6 +458,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
 
         self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
         self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
+        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
 
         self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
             self.input_dtype, self.weight_dtype
@@ -581,6 +582,17 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
             w2_weight_scale = e8m0_shuffle(w2_weight_scale)
             layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+            if self.fp4_dtype is not None:
+                layer.w13_weight = torch.nn.Parameter(
+                    layer.w13_weight.view(self.fp4_dtype),
+                    requires_grad=layer.w13_weight.requires_grad,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    layer.w2_weight.view(self.fp4_dtype),
+                    requires_grad=layer.w2_weight.requires_grad,
+                )
+            torch.cuda.empty_cache()
 
     def get_fused_moe_quant_config(
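
Note (review context, not part of the patch): per the subject line, the
change reinterprets the packed FP4 expert weights as
torch.float4_e2m1fn_x2 so the aiter fused MoE path receives the dtype it
expects for fp4 models. Below is a minimal standalone sketch of that
reinterpretation. It assumes a PyTorch build new enough to expose
torch.float4_e2m1fn_x2 (the getattr probe mirrors the patch's guard for
older builds), and the tensor name and shape are illustrative, not the
actual vLLM layer attributes.

    import torch

    # Probe for the dtype; older PyTorch builds do not define it.
    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    # FP4 (E2M1) weights pack two 4-bit values per byte, so a quantized
    # weight is typically held as a uint8 tensor (illustrative shape:
    # [num_experts, N, K_packed]).
    packed_weight = torch.randint(0, 256, (8, 128, 64), dtype=torch.uint8)

    if fp4_dtype is not None:
        # view(dtype) reinterprets the same storage without copying; both
        # dtypes are one byte per element, so the shape is unchanged.
        fp4_weight = packed_weight.view(fp4_dtype)
        assert fp4_weight.shape == packed_weight.shape
        assert fp4_weight.dtype == fp4_dtype

Since view() aliases the existing storage rather than allocating, the
patch only changes the dtype the kernel sees; the trailing
torch.cuda.empty_cache() releases any allocator slack left over from
weight processing rather than anything created by the view itself.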