From da26dce7b240443316de47694bf2c3707f9bc0b6 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 20 Oct 2025 21:49:09 -0700
Subject: [PATCH] low latency dispatch

Signed-off-by: Zhuohan Li
---
Review notes: line structure restored from a whitespace-collapsed copy;
context-line indentation is reconstructed (confirm against the target file,
or apply with --ignore-whitespace). Hunk counts reflect the revised
run_low_latency(); the blob ids on the index line predate that revision.

 .../layers/moe/ep_kernels_no_abstraction.py   | 77 +++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
index 324dc22919795..983ba95a7aa5b 100644
--- a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
@@ -293,6 +293,7 @@ def low_latency_dispatch(
             num_experts,
             async_finish=False,
             return_recv_hook=True,
+            use_fp8=False,
         )
     )
 
@@ -327,6 +328,81 @@ def low_latency_combine(
     return combined_hidden_states, event_overlap, hook
 
 
+def run_low_latency():
+    """Check low-latency dispatch against an all-gather reference.
+
+    Every rank dispatches random tokens with random top-k routing, then
+    verifies that each token received by a local expert is bit-identical to
+    a token in the globally gathered batch that was routed to that expert.
+    Requires an initialized default process group and a CUDA device.
+    """
+    group = dist.group.WORLD
+    # Read rank/size from the process group (not the environment) so the
+    # gather buffers and the expert split always match the real group.
+    world_size = dist.get_world_size(group)
+    rank = dist.get_rank(group)
+
+    num_experts = 8
+    local_batch_size = 4
+    batch_size = local_batch_size * world_size
+    hidden_size = 2048
+    local_num_experts = num_experts // world_size
+    x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
+    topk = 2
+    num_max_dispatch_tokens_per_rank = local_batch_size
+    low_latency_get_buffer(
+        group, num_max_dispatch_tokens_per_rank, hidden_size, num_experts
+    )
+
+    expert_weights = torch.randn(
+        local_batch_size,
+        num_experts,
+        dtype=torch.float32,
+        device="cuda",
+    )
+    topk_weights, topk_idx = torch.topk(expert_weights, topk, dim=1)
+
+    recv_hidden_states, recv_expert_count, handle, event, hook = low_latency_dispatch(
+        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
+    )
+    # The dispatch wrapper requests a receive hook (return_recv_hook=True);
+    # run it before reading the received buffers.
+    hook()
+
+    # Naive reference: gather every rank's tokens and routing decisions.
+    all_x = [torch.empty_like(x) for _ in range(world_size)]
+    all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
+    all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]
+
+    dist.all_gather(all_x, x)
+    dist.all_gather(all_topk_idx, topk_idx)
+    dist.all_gather(all_topk_weights, topk_weights)
+
+    all_x = torch.cat(all_x, dim=0)
+    all_topk_idx = torch.cat(all_topk_idx, dim=0)
+    all_topk_weights = torch.cat(all_topk_weights, dim=0)
+
+    expert_tok_ids = [[] for _ in range(local_num_experts)]
+
+    expert_range_start = rank * local_num_experts
+    # Pull the counts to the host once; the loops below then use plain ints.
+    recv_counts = recv_expert_count.tolist()
+    for i, cnt in enumerate(recv_counts):
+        for j in range(cnt):
+            for k in range(batch_size):
+                # Exact comparison: a received row must equal its source row.
+                if torch.equal(recv_hidden_states[i, j], all_x[k]):
+                    expert_tok_ids[i].append(k)
+                    assert i + expert_range_start in all_topk_idx[k]
+                    break
+            else:
+                raise AssertionError(
+                    f"rank {rank}: expert {i} slot {j} matches no source token"
+                )
+
+    assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_counts
+
+
 if __name__ == "__main__":
     torch.distributed.init_process_group(
         backend="nccl",
@@ -335,3 +411,4 @@ if __name__ == "__main__":
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     torch.cuda.set_device(local_rank)
     run_high_throughput()
+    run_low_latency()