[Bugfix] Fix for 24530. Fix naive all2all shared expert overlap. (#24538)

commit b23fb78623 (parent 561f38dc3c)
Author: bnellnm
Date:   2025-09-09 20:53:53 -04:00 (committed by GitHub)


@@ -1755,9 +1755,6 @@ class FusedMoE(CustomOp):
            self.dp_size > 1
            and not self.moe_parallel_config.use_deepep_ht_kernels
            and not self.moe_config.use_flashinfer_cutlass_kernels)
-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
@@ -1768,6 +1765,10 @@ class FusedMoE(CustomOp):
         else:
             shared_output = None
 
+        if do_naive_dispatch_combine:
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -1800,8 +1801,9 @@ class FusedMoE(CustomOp):
             final_hidden_states,
         )
 
-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
+        def reduce_output(states: torch.Tensor,
+                          do_combine: bool = True) -> torch.Tensor:
+            if do_naive_dispatch_combine and do_combine:
                 states = get_ep_group().combine(states)
 
             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
@@ -1810,10 +1812,11 @@ class FusedMoE(CustomOp):
             return states
 
         if self.shared_experts is None:
+            assert not isinstance(final_hidden_states, tuple)
             return reduce_output(final_hidden_states)
         else:
             return (
-                reduce_output(final_hidden_states[0]),
+                reduce_output(final_hidden_states[0], do_combine=False),
                 reduce_output(final_hidden_states[1]),
             )
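
In short, the patch moves the naive all2all dispatch to after the shared-expert call, so the shared experts run on the pre-dispatch tokens each rank owns (and can overlap with the all2all), and it threads a do_combine flag through reduce_output so the shared output, which was never dispatched, skips the EP combine; the new assert documents that a tuple result only arises when shared experts exist. Below is a minimal, single-process sketch of that ordering under toy assumptions: ToyEpGroup, shared_expert, and routed_expert are illustrative stand-ins, not vLLM's API.

# A toy two-rank "EP group" where dispatch() all-gathers tokens and combine()
# sum-reduces partial expert outputs and scatters them back to the owning rank.
# All names here are hypothetical stand-ins for illustration only.
import torch


class ToyEpGroup:
    """Stands in for get_ep_group(); operates on a list of per-rank tensors."""

    def dispatch(self, per_rank: list[torch.Tensor]) -> list[torch.Tensor]:
        # Naive all2all dispatch: every rank ends up with all ranks' tokens.
        full = torch.cat(per_rank, dim=0)
        return [full.clone() for _ in per_rank]

    def combine(self, partials: list[torch.Tensor]) -> list[torch.Tensor]:
        # Sum the per-rank partial outputs, then split tokens back per rank.
        summed = sum(partials)
        return list(summed.chunk(len(partials), dim=0))


def shared_expert(x: torch.Tensor) -> torch.Tensor:
    return 2.0 * x  # dense expert applied to every token


def routed_expert(x: torch.Tensor) -> torch.Tensor:
    return x + 1.0  # stand-in for the fused-MoE matmul


ep = ToyEpGroup()
hidden = [torch.randn(4, 8) for _ in range(2)]  # 4 tokens on each of 2 ranks

# 1) Shared experts run BEFORE dispatch, on the tokens each rank owns, so
#    their output is already laid out per rank.
shared_out = [shared_expert(h) for h in hidden]

# 2) The naive all2all dispatch then replicates tokens across EP ranks.
dispatched = ep.dispatch(hidden)

# 3) Each rank contributes only the output of the experts it owns; here we
#    fake expert ownership by masking alternating tokens per rank.
routed_partials = []
for rank, d in enumerate(dispatched):
    mask = torch.zeros(d.shape[0], 1)
    mask[rank::2] = 1.0
    routed_partials.append(routed_expert(d) * mask)

# 4) Only the routed path goes through combine; the shared output was never
#    dispatched, so it must skip it (reduce_output(..., do_combine=False)).
routed_out = ep.combine(routed_partials)

final = [s + r for s, r in zip(shared_out, routed_out)]

Before the patch, dispatch ran first, so the shared experts operated on the replicated post-dispatch tokens and their output was combined along with the routed output, which is the mismatch the reordering and the do_combine flag fix.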