mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 23:29:08 +08:00
Fix MTP with deepep_low_latency (#25904)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
502640c3f9
commit
13cdc02173
@ -1899,6 +1899,15 @@ class FusedMoE(CustomOp):
|
|||||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||||
|
|
||||||
|
# If there are shared experts but we are not using a modular kernel,
|
||||||
|
# the shared experts must be called here
|
||||||
|
if (not isinstance(self.quant_method.fused_experts,
|
||||||
|
FusedMoEModularKernel)
|
||||||
|
and self.shared_experts is not None):
|
||||||
|
shared_output = self.shared_experts(staged_hidden_states)
|
||||||
|
else:
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
@ -1922,8 +1931,13 @@ class FusedMoE(CustomOp):
|
|||||||
logical_replica_count=self.logical_replica_count,
|
logical_replica_count=self.logical_replica_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert self.shared_experts is None or isinstance(
|
if shared_output is not None:
|
||||||
final_hidden_states, tuple)
|
assert not isinstance(final_hidden_states, tuple)
|
||||||
|
assert self.shared_experts is not None
|
||||||
|
final_hidden_states = (
|
||||||
|
shared_output,
|
||||||
|
final_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
if self.zero_expert_num is not None and self.zero_expert_num > 0:
|
if self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||||
assert isinstance(final_hidden_states, tuple)
|
assert isinstance(final_hidden_states, tuple)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user