[BugFix] Fix misprint introduced by modular_kernel refactoring. (#28728)

Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
This commit is contained in:
Andrey Khalyavin 2025-11-14 21:58:18 +03:00 committed by GitHub
parent cec275efce
commit fd4555089a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1060,7 +1060,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
expert_map=expert_map,
a1q_scale=_slice_scales(a1q_scale, s, e),
a2_scale=_slice_scales(self.fused_experts.a2_scale, e, e),
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,