Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix for 24530. Fix naive all2all shared expert overlap. (#24538)
parent 561f38dc3c
commit b23fb78623
@@ -1755,9 +1755,6 @@ class FusedMoE(CustomOp):
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
             and not self.moe_config.use_flashinfer_cutlass_kernels)
-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
@@ -1768,6 +1765,10 @@ class FusedMoE(CustomOp):
         else:
             shared_output = None
 
+        if do_naive_dispatch_combine:
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
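On the naive all2all path, get_ep_group().dispatch() replaces hidden_states with the dispatched tensor that carries tokens from every DP rank. The two hunks above move that dispatch below the shared-expert call, so the shared experts keep seeing this rank's original, pre-dispatch tokens. The following is a minimal, self-contained sketch of the shape mismatch the old ordering causes; naive_dispatch, shared_expert, and DP_SIZE are hypothetical stand-ins, not vLLM APIs.

# Minimal, self-contained sketch (not vLLM code) of the ordering problem.
# `naive_dispatch` and `shared_expert` are hypothetical stand-ins and
# DP_SIZE is an assumed toy world size.
import torch

DP_SIZE = 2                    # pretend data-parallel world size
tokens = torch.randn(4, 8)     # this rank's tokens

def naive_dispatch(x: torch.Tensor) -> torch.Tensor:
    # Naive all2all dispatch: afterwards this rank holds the tokens of all
    # DP ranks (modelled here as a repeat along the token dimension).
    return x.repeat(DP_SIZE, 1)

def shared_expert(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the shared-expert MLP; shape-preserving.
    return x * 2.0

# Old ordering: dispatch ran first, so the shared experts consumed the
# dispatched (DP_SIZE * num_tokens) tensor instead of this rank's tokens.
bad = shared_expert(naive_dispatch(tokens))
assert bad.shape[0] == DP_SIZE * tokens.shape[0]   # wrong token count

# Patched ordering: the shared experts run on the local, pre-dispatch
# tokens; only the routed-expert path consumes the dispatched tensor.
good = shared_expert(tokens)
assert good.shape == tokens.shape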
@@ -1800,8 +1801,9 @@ class FusedMoE(CustomOp):
             final_hidden_states,
         )
 
-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
+        def reduce_output(states: torch.Tensor,
+                          do_combine: bool = True) -> torch.Tensor:
+            if do_naive_dispatch_combine and do_combine:
                 states = get_ep_group().combine(states)
 
             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
@@ -1810,10 +1812,11 @@ class FusedMoE(CustomOp):
             return states
 
         if self.shared_experts is None:
             assert not isinstance(final_hidden_states, tuple)
             return reduce_output(final_hidden_states)
         else:
             return (
-                reduce_output(final_hidden_states[0]),
+                reduce_output(final_hidden_states[0], do_combine=False),
                 reduce_output(final_hidden_states[1]),
             )
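The reduce_output changes are the other half of the fix: only the routed-expert output (final_hidden_states[1]) was produced from dispatched tokens and still needs get_ep_group().combine(), while the shared-expert output (final_hidden_states[0], judging by the shared_output handling above) never left this rank, so combining it would produce the wrong result. Below is a minimal, self-contained sketch of that asymmetry; naive_combine, DP_SIZE, NUM_TOKENS, and HIDDEN are hypothetical stand-ins for the real combine and runtime sizes.

# Minimal, self-contained sketch (not vLLM code) of why only the routed
# output goes through combine. `naive_combine` stands in for
# get_ep_group().combine(); DP_SIZE, NUM_TOKENS, HIDDEN are assumed toy sizes.
import torch

DP_SIZE = 2
NUM_TOKENS = 4
HIDDEN = 8

def naive_combine(y: torch.Tensor) -> torch.Tensor:
    # Reduce the dispatched copies back down to this rank's token count.
    return y.view(DP_SIZE, NUM_TOKENS, HIDDEN).sum(dim=0)

def reduce_output(states: torch.Tensor,
                  do_combine: bool = True) -> torch.Tensor:
    # Simplified version of the patched helper: combine is optional per output.
    if do_combine:
        states = naive_combine(states)
    return states

routed_out = torch.randn(DP_SIZE * NUM_TOKENS, HIDDEN)  # dispatched shape
shared_out = torch.randn(NUM_TOKENS, HIDDEN)            # local, never dispatched

final = (
    reduce_output(shared_out, do_combine=False),  # skip combine for shared output
    reduce_output(routed_out),                    # combine the routed output
)
assert final[0].shape == final[1].shape == (NUM_TOKENS, HIDDEN)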