Fix MTP with deepep_low_latency (#25904)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Matthew Bonanni
Date: 2025-10-02 17:29:49 -04:00
Committer: yewentao256
Parent: abc55b1fe5
Commit: 72c5dd0310

@@ -1899,6 +1899,15 @@ class FusedMoE(CustomOp):
         staged_hidden_states.copy_(hidden_states, non_blocking=True)
         staged_router_logits.copy_(router_logits, non_blocking=True)
 
+        # If there are shared experts but we are not using a modular kernel,
+        # the shared experts must be called here
+        if (not isinstance(self.quant_method.fused_experts,
+                           FusedMoEModularKernel)
+                and self.shared_experts is not None):
+            shared_output = self.shared_experts(staged_hidden_states)
+        else:
+            shared_output = None
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
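
The first hunk decides where the shared experts run: a FusedMoEModularKernel computes them as part of its fused path (and apply() then returns a tuple, as the second hunk assumes), so only a non-modular backend calls self.shared_experts explicitly on the staged inputs. A minimal sketch of that dispatch, using hypothetical stand-ins (ModularBackend, PlainBackend, shared_mlp) rather than the real vLLM classes:

# Hypothetical sketch of the shared-experts dispatch; ModularBackend,
# PlainBackend, and shared_mlp are illustrative stand-ins, not vLLM APIs.
from typing import Callable, Optional

import torch


class ModularBackend:
    """Backend that runs the shared experts inside its own fused path."""


class PlainBackend:
    """Backend that computes only the routed experts."""


def maybe_run_shared_experts(
    backend: object,
    shared_mlp: Optional[Callable[[torch.Tensor], torch.Tensor]],
    staged_hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    # Only the non-modular path computes the shared experts here; a modular
    # backend would otherwise produce the shared output twice.
    if not isinstance(backend, ModularBackend) and shared_mlp is not None:
        return shared_mlp(staged_hidden_states)
    return None


x = torch.randn(4, 8)
print(maybe_run_shared_experts(PlainBackend(), lambda t: t * 2, x) is None)    # False
print(maybe_run_shared_experts(ModularBackend(), lambda t: t * 2, x) is None)  # True
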
@@ -1922,8 +1931,13 @@ class FusedMoE(CustomOp):
             logical_replica_count=self.logical_replica_count,
         )
 
-        assert self.shared_experts is None or isinstance(
-            final_hidden_states, tuple)
+        if shared_output is not None:
+            assert not isinstance(final_hidden_states, tuple)
+            assert self.shared_experts is not None
+            final_hidden_states = (
+                shared_output,
+                final_hidden_states,
+            )
 
         if self.zero_expert_num is not None and self.zero_expert_num > 0:
             assert isinstance(final_hidden_states, tuple)
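
The second hunk normalizes the output shape: the modular kernel already returns a (shared, routed) tuple, while the explicit path above returns a single tensor, so the shared output computed earlier is packed into the same tuple layout before the zero-expert check. A hedged sketch of that normalization, with illustrative names (routed_out, shared_out) that are not vLLM variables:

# Hypothetical sketch of the output normalization; routed_out and shared_out
# are illustrative names, not vLLM variables.
from typing import Optional, Union

import torch

MoEOutput = Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]


def normalize_moe_output(
    routed_out: MoEOutput,
    shared_out: Optional[torch.Tensor],
) -> MoEOutput:
    # If the shared output was computed outside the kernel, the kernel result
    # must be a plain tensor; wrap both into the (shared, routed) tuple that a
    # modular kernel would have returned, so downstream checks see one shape.
    if shared_out is not None:
        assert not isinstance(routed_out, tuple)
        return (shared_out, routed_out)
    return routed_out


routed = torch.randn(4, 8)
shared = torch.randn(4, 8)
print(isinstance(normalize_moe_output(routed, shared), tuple))          # True
print(isinstance(normalize_moe_output((shared, routed), None), tuple))  # True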