[Bugfix] Fix FusedMoEModularKernel for triton backend (#28913)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang 2025-11-18 21:05:22 -08:00 committed by GitHub
parent 4c23690f43
commit 468a8d72ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -755,8 +755,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w13_weight = w13_weight
self.w2_weight = w2_weight
layer.w13_weight = Parameter(w13_weight.storage.data, requires_grad=False)
layer.w2_weight = Parameter(w2_weight.storage.data, requires_grad=False)
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@ -1065,8 +1067,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight,
w2=self.w2_weight,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,