From 99e2379b163b7b9a6211fca38c219db010d51669 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 20 Oct 2025 22:00:01 -0700 Subject: [PATCH] low latency combine Signed-off-by: Zhuohan Li --- .../layers/moe/ep_kernels_no_abstraction.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py index 983ba95a7aa5b..22c6bf7147cc2 100644 --- a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py +++ b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py @@ -371,8 +371,8 @@ def run_low_latency(): all_topk_idx = torch.cat(all_topk_idx, dim=0) all_topk_weights = torch.cat(all_topk_weights, dim=0) + # Verification expert_tok_ids = [[] for _ in range(local_num_experts)] - expert_range_start = rank * local_num_experts for i, cnt in enumerate(recv_expert_count): for j in range(cnt): @@ -386,6 +386,19 @@ def run_low_latency(): assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist() + # Combine + combined_hidden_states, event_overlap, hook = low_latency_combine( + recv_hidden_states, topk_idx, topk_weights, handle + ) + hook() + + torch.testing.assert_close( + combined_hidden_states.to(torch.float32), + x * topk_weights.sum(dim=-1).unsqueeze(-1), + atol=1e-2, + rtol=1e-2, + ) + if __name__ == "__main__": torch.distributed.init_process_group(