mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 04:54:30 +08:00
[Rocm][fused_moe][fp4] view weight to torch.float4_e2m1fn_x2 when running aiter fused moe for fp4 model (#27474)
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
This commit is contained in:
parent
4673e465ff
commit
b06b9470ca
@ -458,6 +458,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
|
||||
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
|
||||
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
|
||||
|
||||
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
|
||||
self.input_dtype, self.weight_dtype
|
||||
@ -581,6 +582,17 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
|
||||
if self.fp4_dtype is not None:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
layer.w13_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w13_weight.requires_grad,
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
layer.w2_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w2_weight.requires_grad,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user