diff --git a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
index b496733f23fd5..0157a2ca7d28d 100644
--- a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # torchrun --nproc_per_node=2 vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py # noqa: E501
+
+# type: ignore
 import os
 
 import torch
@@ -131,12 +133,13 @@ if __name__ == "__main__":
     group = dist.group.WORLD
     num_experts = 8
     local_batch_size = 4
+    batch_size = local_batch_size * world_size
     hidden_size = 128
     local_num_experts = num_experts // group.size()
     x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
     hidden_bytes = get_hidden_bytes(x)
     get_buffer(group, hidden_bytes)
-    topk = 4
+    topk = 2
 
     expert_weights = torch.randn(
         local_batch_size,
@@ -161,16 +164,11 @@
         num_experts,
     )
     # print(f"rank {rank} recv_x: {recv_x.shape=}")
-    output_recv_topk_idx = recv_topk_idx + torch.where(
+    recv_topk_global_idx = recv_topk_idx + torch.where(
         recv_topk_idx == -1,
         0,
         rank * local_num_experts,
     )
-    print(f"rank {rank} recv_topk_idx: {recv_topk_idx.shape=} {output_recv_topk_idx}")
-    # print(
-    #     f"rank {rank} recv_topk_weights: {recv_topk_weights.shape=} "
-    #     f"{recv_topk_weights}"
-    # )
 
     # Dispatch naive
     all_x = [torch.empty_like(x) for _ in range(world_size)]
@@ -185,5 +183,34 @@
     all_topk_idx = torch.cat(all_topk_idx, dim=0)
     all_topk_weights = torch.cat(all_topk_weights, dim=0)
 
-    assert isinstance(all_topk_idx, torch.Tensor)
-    print(f"rank {rank} all_topk_idx: {all_topk_idx.shape=} {all_topk_idx}")
+    expert_range_start = rank * local_num_experts
+    expert_range_end = (rank + 1) * local_num_experts
+    recv_i = -1
+    expert_inputs_ground_truth = []
+    # Verification: compare the dispatch output against the naive ground truth
+    for i in range(batch_size):
+        activated_on_this_rank = False
+        for j in range(topk):
+            if expert_range_start <= all_topk_idx[i, j] < expert_range_end:
+                if not activated_on_this_rank:
+                    activated_on_this_rank = True
+                    recv_i += 1
+                    for prev in range(j):
+                        # Earlier topk slots went to other ranks and must be masked
+                        assert recv_topk_idx[recv_i, prev] == -1
+                        assert recv_topk_weights[recv_i, prev] == 0.0, (
+                            f"{recv_topk_weights[recv_i, prev]=}"
+                        )
+                    expert_inputs_ground_truth.append(all_x[i])
+
+                assert (
+                    recv_topk_idx[recv_i, j] == all_topk_idx[i, j] - expert_range_start
+                )
+                assert recv_topk_weights[recv_i, j] == all_topk_weights[i, j]
+            else:
+                if activated_on_this_rank:
+                    assert recv_topk_idx[recv_i, j] == -1
+                    assert recv_topk_weights[recv_i, j] == 0.0
+
+    expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
+    torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)
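
A minimal standalone sketch (not part of the patch) of the local-to-global expert-id remap the diff renames to recv_topk_global_idx. The rank, local_num_experts, and recv_topk_idx values below are hypothetical toy stand-ins; the remap expression itself is the one used in the patch:

    import torch

    # Hypothetical stand-ins for one rank's state (illustration only).
    rank = 1
    local_num_experts = 4  # global expert ids 4..7 live on this rank
    # Local expert ids per received token; -1 marks topk slots routed elsewhere.
    recv_topk_idx = torch.tensor([[0, 2, -1, -1], [3, -1, -1, -1]])

    # Shift local ids into the global id range. torch.where adds 0 wherever the
    # slot carries the -1 sentinel, so masked entries survive the shift unchanged.
    recv_topk_global_idx = recv_topk_idx + torch.where(
        recv_topk_idx == -1, 0, rank * local_num_experts
    )
    print(recv_topk_global_idx)
    # tensor([[ 4,  6, -1, -1],
    #         [ 7, -1, -1, -1]])

Adding 0 under the mask keeps the -1 sentinel intact without a separate masked_fill pass, which is what lets the verification loop rely on the (-1 index, 0.0 weight) convention.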