low latency dispatch

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Zhuohan Li 2025-10-20 21:49:09 -07:00
parent 48dcc72d7e
commit da26dce7b2


@@ -293,6 +293,7 @@ def low_latency_dispatch(
             num_experts,
             async_finish=False,
             return_recv_hook=True,
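+            # New flag; presumably False keeps the dispatched payload in bf16
+            # rather than quantizing it to fp8 for transfer.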
+            use_fp8=False,
         )
     )
 
@@ -327,6 +328,65 @@ def low_latency_combine(
     return combined_hidden_states, event_overlap, hook
 
 
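+# End-to-end smoke test for the low-latency path: dispatch with the new kernel,
+# then verify the routed tokens against a naive all_gather-based reference.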
+def run_low_latency():
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    rank = int(os.environ.get("RANK", 0))
+    group = dist.group.WORLD
+
+    num_experts = 8
+    local_batch_size = 4
+    batch_size = local_batch_size * world_size
+    hidden_size = 2048
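+    # Experts are sharded evenly across ranks, one contiguous block per rank.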
+    local_num_experts = num_experts // group.size()
+
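+    # Random bf16 activations standing in for this rank's local tokens.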
+    x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
+    topk = 2
+    num_max_dispatch_tokens_per_rank = local_batch_size
+
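+    # Size and register the communication buffer used by the dispatch/combine kernels.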
+    low_latency_get_buffer(
+        group, num_max_dispatch_tokens_per_rank, hidden_size, num_experts
+    )
+
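+    # Random router logits; top-k selects each token's experts and gating weights.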
+    expert_weights = torch.randn(
+        local_batch_size,
+        num_experts,
+        dtype=torch.float32,
+        device="cuda",
+    )
+    topk_weights, topk_idx = torch.topk(expert_weights, topk, dim=1)
+
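+    # Dispatch local tokens to their experts. Since the wrapper sets
+    # return_recv_hook=True, hook() must run before the received buffers are read.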
+    recv_hidden_states, recv_expert_count, handle, event, hook = low_latency_dispatch(
+        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
+    )
+    hook()
+
+    # Naive reference dispatch: gather every rank's tokens and routing choices.
+    all_x = [torch.empty_like(x) for _ in range(world_size)]
+    all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
+    all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]
+    dist.all_gather(all_x, x)
+    dist.all_gather(all_topk_idx, topk_idx)
+    dist.all_gather(all_topk_weights, topk_weights)
+    all_x = torch.cat(all_x, dim=0)
+    all_topk_idx = torch.cat(all_topk_idx, dim=0)
+    all_topk_weights = torch.cat(all_topk_weights, dim=0)
+
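+    # For each local expert, match every received row back to its source token by
+    # exact comparison, and check that the token was in fact routed to this expert.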
+    expert_tok_ids = [[] for _ in range(local_num_experts)]
+    expert_range_start = rank * local_num_experts
+    for i, cnt in enumerate(recv_expert_count):
+        for j in range(cnt):
+            for k in range(batch_size):
+                if torch.allclose(
+                    recv_hidden_states[i, j], all_x[k], rtol=0.0, atol=0.0
+                ):
+                    expert_tok_ids[i].append(k)
+                    assert i + expert_range_start in all_topk_idx[k]
+                    break
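+    # Every local expert must have matched exactly as many tokens as the kernel reported.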
+    assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
+
+
 if __name__ == "__main__":
     torch.distributed.init_process_group(
         backend="nccl",
@@ -335,3 +395,4 @@ if __name__ == "__main__":
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     torch.cuda.set_device(local_rank)
     run_high_throughput()
+    run_low_latency()
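
For reference, a minimal way to exercise the new test, assuming the script is saved as test_low_latency.py (the file name is an assumption) on a node with at least two GPUs:

torchrun --nproc_per_node=2 test_low_latency.py

torchrun populates the WORLD_SIZE, RANK, and LOCAL_RANK environment variables that run_low_latency() and the __main__ block read.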