From 48dcc72d7ea675171a76107cb8090097de8d1ac6 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 20 Oct 2025 18:08:57 -0700
Subject: [PATCH] refactor and add low latency code

Signed-off-by: Zhuohan Li
---
 .../layers/moe/ep_kernels_no_abstraction.py  | 110 ++++++++++++++++--
 1 file changed, 98 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
index dcac696407422..324dc22919795 100644
--- a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
@@ -139,15 +139,9 @@ def combine_forward(
     return combined_x, event
 
 
-if __name__ == "__main__":
-    torch.distributed.init_process_group(
-        backend="nccl",
-        init_method="env://",
-    )
+def run_high_throughput():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
     rank = int(os.environ.get("RANK", 0))
-    torch.cuda.set_device(local_rank)
     group = dist.group.WORLD
 
     num_experts = 8
@@ -182,11 +176,6 @@ if __name__ == "__main__":
         topk_weights,
         num_experts,
     )
-    recv_topk_global_idx = recv_topk_idx + torch.where(
-        recv_topk_idx == -1,
-        0,
-        rank * local_num_experts,
-    )
 
     # Dispatch naive
     all_x = [torch.empty_like(x) for _ in range(world_size)]
@@ -249,3 +238,100 @@ if __name__ == "__main__":
     ):
         appear_count += 1.0
     torch.testing.assert_close(combined_results[i], x[i] * appear_count)
+
+
+# You may call this function at framework initialization
+def low_latency_get_buffer(
+    group: dist.ProcessGroup,
+    num_max_dispatch_tokens_per_rank: int,
+    hidden: int,
+    num_experts: int,
+) -> Buffer:
+    # NOTES: the low-latency mode consumes much more memory than the normal
+    # mode, so we recommend keeping `num_max_dispatch_tokens_per_rank` (the
+    # actual batch size in the decoding engine) below 256
+    global _buffer
+    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
+        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
+    )
+
+    # Allocate a buffer if none exists yet or the existing one is too small
+    if (
+        _buffer is None
+        or _buffer.group != group
+        or not _buffer.low_latency_mode
+        or _buffer.num_rdma_bytes < num_rdma_bytes
+    ):
+        # NOTES: for the best performance, the number of QPs **must** equal
+        # the number of local experts
+        assert num_experts % group.size() == 0
+        _buffer = Buffer(
+            group,
+            0,
+            num_rdma_bytes,
+            low_latency_mode=True,
+            num_qps_per_rank=num_experts // group.size(),
+        )
+    return _buffer
+
+
+def low_latency_dispatch(
+    hidden_states: torch.Tensor,
+    topk_idx: torch.Tensor,
+    num_max_dispatch_tokens_per_rank: int,
+    num_experts: int,
+):
+    global _buffer
+
+    # Do MoE dispatch; compatible with CUDA graphs (but you may need to
+    # restore some buffer state when you replay)
+    recv_hidden_states, recv_expert_count, handle, event, hook = (
+        _buffer.low_latency_dispatch(
+            hidden_states,
+            topk_idx,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+            async_finish=False,
+            return_recv_hook=True,
+        )
+    )
+
+    # NOTES: the actual tensors will not be received until you call `hook()`.
+    # This is useful for two-micro-batch overlapping, and it occupies **no
+    # SMs**. If you don't want to overlap, set `return_recv_hook=False`.
+    # Later, you can use a compatible grouped GEMM library to do the
+    # computation with this specific format
+    return recv_hidden_states, recv_expert_count, handle, event, hook
+
+
+def low_latency_combine(
+    hidden_states: torch.Tensor,
+    topk_idx: torch.Tensor,
+    topk_weights: torch.Tensor,
+    handle: tuple,
+):
+    global _buffer
+
+    # Do MoE combine; compatible with CUDA graphs (but you may need to
+    # restore some buffer state when you replay)
+    combined_hidden_states, event_overlap, hook = _buffer.low_latency_combine(
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        handle,
+        async_finish=False,
+        return_recv_hook=True,
+    )
+
+    # NOTES: the hook behaves the same as described in the dispatch kernel
+    return combined_hidden_states, event_overlap, hook
+
+
+if __name__ == "__main__":
+    torch.distributed.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    run_high_throughput()
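
Below is a minimal usage sketch (not part of the patch) of how the new low-latency entry points might be driven from a single decode step. The import path mirrors the file touched above, but `apply_experts`, the shape comments, and the default of 128 for `num_max_dispatch_tokens_per_rank` are illustrative assumptions.

import torch
import torch.distributed as dist

# Assumed import path, matching the module modified in this patch
from vllm.model_executor.layers.moe.ep_kernels_no_abstraction import (
    low_latency_combine,
    low_latency_dispatch,
    low_latency_get_buffer,
)


def apply_experts(recv_hidden_states, recv_expert_count):
    # Hypothetical stand-in for the per-expert computation. Depending on the
    # DeepEP configuration, the dispatch output may be FP8 (a tensor/scale
    # pair) rather than a plain BF16 tensor; this identity assumes BF16 and
    # keeps the dispatch layout so that combine can consume it directly.
    return recv_hidden_states


def low_latency_moe_step(
    hidden_states: torch.Tensor,  # [num_tokens, hidden], bf16 (assumed)
    topk_idx: torch.Tensor,  # [num_tokens, num_topk], int64
    topk_weights: torch.Tensor,  # [num_tokens, num_topk], fp32
    num_experts: int,
    num_max_dispatch_tokens_per_rank: int = 128,  # assumed decode batch bound
) -> torch.Tensor:
    # Ensure the RDMA buffer exists (normally done once at framework init)
    low_latency_get_buffer(
        dist.group.WORLD,
        num_max_dispatch_tokens_per_rank,
        hidden_states.size(1),
        num_experts,
    )

    recv_hidden_states, recv_expert_count, handle, event, hook = (
        low_latency_dispatch(
            hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
        )
    )
    # With return_recv_hook=True, the receive completes only when the hook
    # runs; a real engine would overlap the other micro-batch's compute here
    hook()

    expert_out = apply_experts(recv_hidden_states, recv_expert_count)

    combined_hidden_states, _event, combine_hook = low_latency_combine(
        expert_out, topk_idx, topk_weights, handle
    )
    combine_hook()
    return combined_hidden_states

Calling `hook()` immediately, as above, reproduces the non-overlapped behavior; the hook-based receive exists so that a scheduler can slot independent work between dispatch/combine issue and completion without occupying any SMs.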