[Bug] Fix DeepEP low latency assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) (#27682)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Wentao Ye on 2025-10-29 14:50:39 -04:00; committed by GitHub
parent accb8fab07
commit fcb1d570bb


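The failing assert compares the last dimension of the pre-allocated batched_router_logits staging buffer against the router logits the gate actually produces. In the DeepEP low-latency path with redundant experts (EPLB), the buffer was sized from moe.num_experts, which evidently reflects a count larger than the logical one, so the two widths diverge and the assert fires. Below is a minimal standalone sketch of that mismatch; the concrete numbers are made up and the tensors are plain CPU buffers, not the vLLM code path.

import torch

# Illustrative counts, not taken from the PR.
num_logical_experts = 64        # experts the router actually scores
num_redundant_experts = 8       # extra replicas added for load balancing (EPLB)
global_num_experts = num_logical_experts + num_redundant_experts
max_num_tokens = 16

# The gate produces logits only over the logical experts.
full_router_logits = torch.randn(max_num_tokens, num_logical_experts)

# Old sizing (global count) vs. new sizing (logical count) of the staging buffer.
buggy_buffer = torch.zeros(max_num_tokens, global_num_experts)
fixed_buffer = torch.zeros(max_num_tokens, num_logical_experts)

print(buggy_buffer.size(-1) == full_router_logits.size(-1))  # False -> assert fired
print(fixed_buffer.size(-1) == full_router_logits.size(-1))  # True  -> assert passes
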
@@ -1135,6 +1135,7 @@ class FusedMoE(CustomOp):
         )
         self.global_num_experts = num_experts + num_redundant_experts
+        self.logical_num_experts = num_experts
         self.zero_expert_num = zero_expert_num
         self.zero_expert_type = zero_expert_type
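
The hunk above stores the logical expert count alongside the existing global count; under expert-parallel load balancing the two differ by the number of redundant replicas, and they coincide otherwise. A minimal sketch of just that bookkeeping, using a hypothetical stand-in class and omitting everything else the real __init__ does:

class _ExpertCounts:
    """Sketch of the two counters kept in FusedMoE.__init__ (simplified)."""

    def __init__(self, num_experts: int, num_redundant_experts: int = 0):
        # Physical experts, including redundant replicas placed for load balancing.
        self.global_num_experts = num_experts + num_redundant_experts
        # Experts the model's router actually emits logits for.
        self.logical_num_experts = num_experts


counts = _ExpertCounts(num_experts=64, num_redundant_experts=8)
print(counts.global_num_experts, counts.logical_num_experts)  # 72 64
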
@@ -1998,13 +1999,12 @@ class FusedMoE(CustomOp):
         moe = self.moe_config
         # Note here we use `num_experts` which is logical expert count
         if self.vllm_config.parallel_config.enable_dbo:
             states_shape = (2, moe.max_num_tokens, self.hidden_size)
-            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+            logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
         else:
             states_shape = (moe.max_num_tokens, self.hidden_size)
-            logits_shape = (moe.max_num_tokens, moe.num_experts)
+            logits_shape = (moe.max_num_tokens, self.logical_num_experts)
         self.batched_hidden_states = torch.zeros(
             states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
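
With that change, the staging buffers are shaped from the logical expert count, so the last dimension of batched_router_logits matches whatever the gate emits even when redundant experts are configured. A simplified standalone sketch of the allocation follows; the helper name make_staging_buffers, the float32 logits dtype, and the CPU placement are illustrative assumptions (the real allocation happens inside FusedMoE on torch.cuda.current_device()), and enable_dbo is taken to toggle dual-batch overlap, which stages two microbatches at once.

import torch

def make_staging_buffers(
    max_num_tokens: int,
    hidden_size: int,
    logical_num_experts: int,
    in_dtype: torch.dtype,
    enable_dbo: bool,
):
    # With dual-batch overlap, two microbatches are staged at once, hence the leading 2.
    if enable_dbo:
        states_shape = (2, max_num_tokens, hidden_size)
        logits_shape = (2, max_num_tokens, logical_num_experts)
    else:
        states_shape = (max_num_tokens, hidden_size)
        logits_shape = (max_num_tokens, logical_num_experts)
    batched_hidden_states = torch.zeros(states_shape, dtype=in_dtype)
    batched_router_logits = torch.zeros(logits_shape, dtype=torch.float32)  # dtype assumed
    return batched_hidden_states, batched_router_logits


# The last dim now tracks the logical count, independent of redundant replicas.
_, logits_buf = make_staging_buffers(16, 1024, 64, torch.bfloat16, enable_dbo=False)
print(logits_buf.shape)  # torch.Size([16, 64])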