add combine example

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Zhuohan Li 2025-10-20 17:45:31 -07:00
parent 4e2abe99b7
commit e3e2bb3865
2 changed files with 45 additions and 9 deletions

View File

@ -120,6 +120,25 @@ def dispatch_forward(
)
def combine_forward(
    x: torch.Tensor, handle: tuple, previous_event: EventOverlap | None = None
) -> tuple[torch.Tensor, EventOverlap]:
    """Run the MoE combine step through the module-level ``_buffer``.

    Args:
        x: Tensor forwarded as the first positional argument of
            ``_buffer.combine``.
        handle: Tuple forwarded as the second positional argument of
            ``_buffer.combine`` (presumably the handle produced by the
            matching dispatch call -- confirm against the caller).
        previous_event: Optional event forwarded to ``_buffer.combine``.
            When it is provided, ``allocate_on_comm_stream`` is enabled too.

    Returns:
        A ``(combined_x, event)`` pair taken from the first and third items
        of the ``_buffer.combine`` result; the middle item is discarded.
    """
    global _buffer

    # Allocation on the communication stream is requested only when the
    # caller supplied an event to chain after.
    chain_after_event = previous_event is not None

    # Do MoE combine.
    # For more advanced usages, please refer to the docs of the `combine` function.
    combine_outputs = _buffer.combine(
        x,
        handle,
        async_finish=True,
        previous_event=previous_event,
        allocate_on_comm_stream=chain_after_event,
    )
    merged_x, _, finish_event = combine_outputs

    # For event management, please refer to the docs of the `EventOverlap` class.
    return merged_x, finish_event
if __name__ == "__main__":
torch.distributed.init_process_group(
backend="nccl",
@ -134,7 +153,7 @@ if __name__ == "__main__":
num_experts = 8
local_batch_size = 4
batch_size = local_batch_size * world_size
hidden_size = 128
hidden_size = 32
local_num_experts = num_experts // group.size()
x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
hidden_bytes = get_hidden_bytes(x)
@ -163,7 +182,6 @@ if __name__ == "__main__":
topk_weights,
num_experts,
)
# print(f"rank {rank} recv_x: {recv_x.shape=}")
recv_topk_global_idx = recv_topk_idx + torch.where(
recv_topk_idx == -1,
0,
@ -198,9 +216,7 @@ if __name__ == "__main__":
for prev in range(j):
# Assert previous tokens are not activated on this rank
assert recv_topk_idx[recv_i, prev] == -1
assert recv_topk_weights[recv_i, prev] == 0.0, (
f"{recv_topk_weights[recv_i, prev]=}"
)
assert recv_topk_weights[recv_i, prev] == 0.0
expert_inputs_ground_truth.append(all_x[i])
assert (
@ -214,3 +230,22 @@ if __name__ == "__main__":
expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)
# NOTE: This dispatch result cannot be directly fed into the GEMM kernel; it
# needs another shuffling/expansion process to cluster the input to each expert.
# The results to combine will be reduced on the local rank first.
# Combine
results = recv_x
combined_results, event = combine_forward(results, handle, event)
for i in range(local_batch_size):
appear_count = 0.0
for j in range(world_size):
expert_range_start = j * local_num_experts
expert_range_end = (j + 1) * local_num_experts
if any(
(expert_range_start <= topk_idx[i]) & (topk_idx[i] < expert_range_end)
):
appear_count += 1.0
torch.testing.assert_close(combined_results[i], x[i] * appear_count)

View File

@ -136,7 +136,7 @@ def run_batched_deepgemm_contiguous_bf16(
aligned_end = start + aligned_ms[i]
expert_ids[start:actual_end] = i
expert_ids[actual_end:aligned_end] = -1
reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
reference_output[start:actual_end] = x[start:actual_end] @ weight[i].t()
start = aligned_end
output = torch.zeros(
@ -152,6 +152,7 @@ def run_batched_deepgemm_contiguous_bf16(
output,
expert_ids,
)
output = output * (expert_ids != -1).unsqueeze(1)
torch.testing.assert_close(output, reference_output)
@ -351,6 +352,6 @@ def run_triton_group_gemm_masked_bf16(
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)