Add a more complicated assert for the correctness check

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Zhuohan Li 2025-10-19 23:30:28 -07:00
parent dcf059ab84
commit 177f5d757f

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# torchrun --nproc_per_node=2 vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py # noqa: E501
# type: ignore
import os
import torch
@ -131,12 +133,13 @@ if __name__ == "__main__":
group = dist.group.WORLD
num_experts = 8
local_batch_size = 4
batch_size = local_batch_size * world_size
hidden_size = 128
local_num_experts = num_experts // group.size()
x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
hidden_bytes = get_hidden_bytes(x)
get_buffer(group, hidden_bytes)
topk = 4
topk = 2
expert_weights = torch.randn(
local_batch_size,
@ -161,16 +164,11 @@ if __name__ == "__main__":
num_experts,
)
# print(f"rank {rank} recv_x: {recv_x.shape=}")
output_recv_topk_idx = recv_topk_idx + torch.where(
recv_topk_global_idx = recv_topk_idx + torch.where(
recv_topk_idx == -1,
0,
rank * local_num_experts,
)
print(f"rank {rank} recv_topk_idx: {recv_topk_idx.shape=} {output_recv_topk_idx}")
# print(
# f"rank {rank} recv_topk_weights: {recv_topk_weights.shape=} "
# f"{recv_topk_weights}"
# )
# Dispatch naive
all_x = [torch.empty_like(x) for _ in range(world_size)]
@ -185,5 +183,34 @@ if __name__ == "__main__":
all_topk_idx = torch.cat(all_topk_idx, dim=0)
all_topk_weights = torch.cat(all_topk_weights, dim=0)
assert isinstance(all_topk_idx, torch.Tensor)
print(f"rank {rank} all_topk_idx: {all_topk_idx.shape=} {all_topk_idx}")
# Check the received dispatch outputs (recv_x, recv_topk_idx,
# recv_topk_weights) against the gathered all_x / all_topk_idx /
# all_topk_weights.
# NOTE(review): this view of the file has lost its indentation; the nesting
# described in the comments below is inferred from statement order -- confirm
# against the real file.
#
# Half-open range [expert_range_start, expert_range_end) of expert ids that
# the checks below treat as belonging to this rank.
expert_range_start = rank * local_num_experts
expert_range_end = (rank + 1) * local_num_experts
# Row index into the recv_* tensors. Starts at -1 and is incremented before
# its first use, so the first matching token maps to row 0.
recv_i = -1
# Collects one all_x[i] per token that has at least one expert id inside
# this rank's range, in increasing order of i.
expert_inputs_ground_truth = []
# Walk every token of the gathered batch in order. The checks assume the
# received rows follow the same order as the tokens in all_x -- presumably
# guaranteed by the dispatch; TODO confirm.
for i in range(batch_size):
# Becomes True once one of this token's top-k experts falls in this rank's
# range, i.e. once the token has been assigned a row recv_i.
activated_on_this_rank = False
for j in range(topk):
# Is the expert in top-k slot j inside this rank's range?
if expert_range_start <= all_topk_idx[i, j] < expert_range_end:
if not activated_on_this_rank:
# First matching slot for this token: claim the next received row.
activated_on_this_rank = True
recv_i += 1
for prev in range(j):
# Earlier top-k slots (0..j-1) of this same row did not match this rank's
# range, so they must be masked: index -1 and weight 0.0.
assert recv_topk_idx[recv_i, prev] == -1
assert recv_topk_weights[recv_i, prev] == 0.0, (
f"{recv_topk_weights[recv_i, prev]=}"
)
# Record this token's input as the expected content of received row recv_i.
expert_inputs_ground_truth.append(all_x[i])
# The received index is the expert id minus expert_range_start (a
# rank-local index), and the weight must equal the gathered weight exactly.
assert (
recv_topk_idx[recv_i, j] == all_topk_idx[i, j] - expert_range_start
)
assert recv_topk_weights[recv_i, j] == all_topk_weights[i, j]
# Presumably pairs with the range check on slot j's expert id: the expert is
# outside this rank's range.
else:
# Only checked once the token already owns a row; slots seen before the
# first match are covered by the `prev` loop above.
if activated_on_this_rank:
assert recv_topk_idx[recv_i, j] == -1
assert recv_topk_weights[recv_i, j] == 0.0
# NOTE(review): torch.stack raises on an empty list, so this presumably fails
# when no token has an expert in this rank's range -- confirm whether that
# case needs handling.
expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
# Zero tolerances: recv_x must match the expected inputs exactly.
torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)