refactor and add low latency code

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Zhuohan Li 2025-10-20 18:08:57 -07:00
parent e3e2bb3865
commit 48dcc72d7e


@@ -139,15 +139,9 @@ def combine_forward(
    return combined_x, event
if __name__ == "__main__":
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
    )
def run_high_throughput():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    torch.cuda.set_device(local_rank)
    group = dist.group.WORLD
    num_experts = 8
@ -182,11 +176,6 @@ if __name__ == "__main__":
        topk_weights,
        num_experts,
    )
    recv_topk_global_idx = recv_topk_idx + torch.where(
        recv_topk_idx == -1,
        0,
        rank * local_num_experts,
    )
    # Dispatch naive
    all_x = [torch.empty_like(x) for _ in range(world_size)]
@ -249,3 +238,100 @@ if __name__ == "__main__":
            ):
                appear_count += 1.0
        torch.testing.assert_close(combined_results[i], x[i] * appear_count)
# You may call this function at framework initialization
def low_latency_get_buffer(
    group: dist.ProcessGroup,
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_experts: int,
) -> Buffer:
    # NOTES: the low-latency mode consumes much more space than the normal mode,
    # so we recommend keeping `num_max_dispatch_tokens_per_rank` (the actual batch
    # size in the decoding engine) below 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
    )
    # Allocate a buffer if none exists yet or the existing one is too small
    if (
        _buffer is None
        or _buffer.group != group
        or not _buffer.low_latency_mode
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        # NOTES: for the best performance, the QP number **must** equal the
        # number of local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(
            group,
            0,
            num_rdma_bytes,
            low_latency_mode=True,
            num_qps_per_rank=num_experts // group.size(),
        )
    return _buffer
def low_latency_dispatch(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    num_max_dispatch_tokens_per_rank: int,
    num_experts: int,
):
    global _buffer
    # Do the MoE dispatch, compatible with CUDA graph (but you may need to restore
    # some buffer state when you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = (
        _buffer.low_latency_dispatch(
            hidden_states,
            topk_idx,
            num_max_dispatch_tokens_per_rank,
            num_experts,
            async_finish=False,
            return_recv_hook=True,
        )
    )
    # NOTES: the actual tensor is not received until you call `hook()`; this is
    # useful for double-batch overlapping, and it occupies **no SMs**
    # If you don't want to overlap, set `return_recv_hook=False`
    # Later, you can use our GEMM library to do the computation with this specific
    # format
    return recv_hidden_states, recv_expert_count, handle, event, hook
def low_latency_combine(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    handle: tuple,
):
    global _buffer
    # Do the MoE combine, compatible with CUDA graph (but you may need to restore
    # some buffer state when you replay)
    combined_hidden_states, event_overlap, hook = _buffer.low_latency_combine(
        hidden_states,
        topk_idx,
        topk_weights,
        handle,
        async_finish=False,
        return_recv_hook=True,
    )
    # NOTES: same receive-hook behavior as described for the dispatch kernel
    return combined_hidden_states, event_overlap, hook
if __name__ == "__main__":
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
run_high_throughput()
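
The `__main__` block above still exercises only the high-throughput path. Below is a rough sketch of how the new helpers could be wired into a single decode step; it is not part of this commit, and the function name `run_low_latency_step`, the placeholder expert computation, and the default of 128 tokens per rank are illustrative assumptions rather than the library's API.

# Hypothetical usage sketch (not part of this commit): one decode step wired
# through the three low-latency helpers defined above.
def run_low_latency_step(
    group: dist.ProcessGroup,
    hidden_states: torch.Tensor,   # [num_tokens, hidden]
    topk_idx: torch.Tensor,        # [num_tokens, topk] expert indices
    topk_weights: torch.Tensor,    # [num_tokens, topk] routing weights
    num_experts: int,
    num_max_dispatch_tokens_per_rank: int = 128,  # illustrative default
):
    # Ensure the shared RDMA buffer exists and is large enough for this shape
    low_latency_get_buffer(
        group, num_max_dispatch_tokens_per_rank, hidden_states.shape[1], num_experts
    )
    # Dispatch; with `return_recv_hook=True` the receive happens inside `hook()`
    recv_hidden_states, recv_expert_count, handle, event, hook = low_latency_dispatch(
        hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
    )
    hook()  # no overlap in this sketch, so trigger the receive right away
    # Placeholder: a real implementation would run the expert GEMMs on the
    # dispatched layout here instead of passing the buffer through unchanged
    expert_out = recv_hidden_states
    # Combine expert outputs back into the original token order
    combined_hidden_states, event_overlap, combine_hook = low_latency_combine(
        expert_out, topk_idx, topk_weights, handle
    )
    combine_hook()
    return combined_hidden_states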