complicated assert for correctness check
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent dcf059ab84
commit 177f5d757f
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# torchrun --nproc_per_node=2 vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py # noqa: E501
 
+# type: ignore
 import os
 
 import torch
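The launch command added above relies on environment variables that torchrun injects into each worker. A minimal sketch of the process-group setup such a script depends on; the actual initialization code sits outside these hunks, so the details below are assumptions, not part of this commit:

# Sketch (assumption): typical setup for a torchrun-launched script;
# the real script's initialization is not shown in this diff.
import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])              # injected by torchrun
world_size = int(os.environ["WORLD_SIZE"])  # injected by torchrun
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")     # NCCL backend for CUDA tensors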
@@ -131,12 +133,13 @@ if __name__ == "__main__":
     group = dist.group.WORLD
     num_experts = 8
     local_batch_size = 4
     batch_size = local_batch_size * world_size
     hidden_size = 128
     local_num_experts = num_experts // group.size()
     x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
     hidden_bytes = get_hidden_bytes(x)
     get_buffer(group, hidden_bytes)
-    topk = 4
+    topk = 2
+
     expert_weights = torch.randn(
         local_batch_size,
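For context, the per-token top-k routing tensors consumed further down are typically derived from random router scores in a benchmark like this. A minimal sketch of that step; the exact construction is truncated out of the hunk, so the names scores, topk_weights, and topk_idx here are assumptions that merely mirror the script's style:

# Sketch (assumption): top-k expert selection from random router logits.
scores = torch.randn(local_batch_size, num_experts, device="cuda").softmax(dim=-1)
topk_weights, topk_idx = torch.topk(scores, k=topk, dim=-1)
# Both outputs have shape (local_batch_size, topk); topk_idx holds global
# expert ids in [0, num_experts).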
@@ -161,16 +164,11 @@ if __name__ == "__main__":
         num_experts,
     )
     # print(f"rank {rank} recv_x: {recv_x.shape=}")
-    output_recv_topk_idx = recv_topk_idx + torch.where(
+    recv_topk_global_idx = recv_topk_idx + torch.where(
         recv_topk_idx == -1,
         0,
         rank * local_num_experts,
     )
-    print(f"rank {rank} recv_topk_idx: {recv_topk_idx.shape=} {output_recv_topk_idx}")
-    # print(
-    #     f"rank {rank} recv_topk_weights: {recv_topk_weights.shape=} "
-    #     f"{recv_topk_weights}"
-    # )
 
     # Dispatch naive
     all_x = [torch.empty_like(x) for _ in range(world_size)]
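The torch.where offset converts the dispatcher's rank-local expert indices back to global ids while leaving the -1 padding untouched, since -1 entries have 0 added to them. The same arithmetic on CPU with toy values, assuming rank 1 of 2 ranks and 4 experts per rank:

# Sketch: the local-to-global expert-index shift with toy values (assumed:
# rank 1, local_num_experts 4); -1 marks top-k slots not routed to this rank.
import torch

rank, local_num_experts = 1, 4
recv_topk_idx = torch.tensor([[2, -1], [0, 3]])
recv_topk_global_idx = recv_topk_idx + torch.where(
    recv_topk_idx == -1, 0, rank * local_num_experts
)
print(recv_topk_global_idx)  # tensor([[ 6, -1], [ 4,  7]])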
@@ -185,5 +183,34 @@ if __name__ == "__main__":
     all_topk_idx = torch.cat(all_topk_idx, dim=0)
     all_topk_weights = torch.cat(all_topk_weights, dim=0)
 
-    assert isinstance(all_topk_idx, torch.Tensor)
-    print(f"rank {rank} all_topk_idx: {all_topk_idx.shape=} {all_topk_idx}")
+    expert_range_start = rank * local_num_experts
+    expert_range_end = (rank + 1) * local_num_experts
+    recv_i = -1
+    expert_inputs_ground_truth = []
+    # Verification
+    for i in range(batch_size):
+        activated_on_this_rank = False
+        for j in range(topk):
+            if expert_range_start <= all_topk_idx[i, j] < expert_range_end:
+                if not activated_on_this_rank:
+                    activated_on_this_rank = True
+                    recv_i += 1
+                    for prev in range(j):
+                        # Assert this token's earlier top-k slots were not routed here
+                        assert recv_topk_idx[recv_i, prev] == -1
+                        assert recv_topk_weights[recv_i, prev] == 0.0, (
+                            f"{recv_topk_weights[recv_i, prev]=}"
+                        )
+                    expert_inputs_ground_truth.append(all_x[i])
+
+                assert (
+                    recv_topk_idx[recv_i, j] == all_topk_idx[i, j] - expert_range_start
+                )
+                assert recv_topk_weights[recv_i, j] == all_topk_weights[i, j]
+            else:
+                if activated_on_this_rank:
+                    assert recv_topk_idx[recv_i, j] == -1
+                    assert recv_topk_weights[recv_i, j] == 0.0
+
+    expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
+    torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)
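The check works because the dispatcher compacts its output: each global token that activates at least one local expert gets exactly one recv row, with -1 indices and 0.0 weights in the top-k slots owned by other ranks, and recv_x rows must match the gathered inputs bit-for-bit (atol=0, rtol=0). The reference tensors come from a naive all-gather path of which only the torch.cat tail is visible above; a minimal sketch of the gather itself, assuming it mirrors the all_x pattern shown in the previous hunk (inferred, not part of this diff):

# Sketch (assumption): the naive all-gather reference whose torch.cat tail
# appears in the hunk above; topk_idx/topk_weights are this rank's routing.
all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
dist.all_gather(all_topk_idx, topk_idx)
all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]
dist.all_gather(all_topk_weights, topk_weights)
# Each list is then concatenated along dim 0 (as in the hunk), yielding
# (batch_size, topk) tensors ordered by rank.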