From 92247c522e216f9d010db1c648dc783dbf141704 Mon Sep 17 00:00:00 2001
From: bnellnm <49004751+bnellnm@users.noreply.github.com>
Date: Wed, 21 May 2025 01:37:08 -0400
Subject: [PATCH] [Bug] Fix moe_sum signature (#18440)

Signed-off-by: Bill Nell
---
 csrc/moe/torch_bindings.cpp   |  2 +-
 tests/kernels/moe/test_moe.py | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 810026d034c0..05f515e2e783 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   // Calculate the result of moe by summing up the partial results
   // from all selected experts.
-  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.def("moe_sum(Tensor input, Tensor! output) -> ()");
   m.impl("moe_sum", torch::kCUDA, &moe_sum);
 
   // Aligning the number of tokens to be processed by each expert such
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 43ddc79fcb81..9a8ac242af79 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
     opcheck(torch.ops._moe_C.moe_align_block_size,
             (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
              num_tokens_post_pad))
+
+
+@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
+    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
+    actual = torch.empty((m, k), device="cuda", dtype=dtype)
+
+    expected = input.sum(dim=1)
+    torch.ops._moe_C.moe_sum(input, actual)
+
+    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
+
+    opcheck(torch.ops._moe_C.moe_sum, (input, actual))
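
Note (not part of the patch): the fix above hinges on the PyTorch custom-op schema
convention that "Tensor!" marks an argument the kernel mutates in place. moe_sum reads
`input` and writes the per-token sum into the pre-allocated `output`, so the "!" belongs
on `output`; the added test exercises this with opcheck. The following is a minimal
Python sketch of that schema rule, using a hypothetical "demo_moe" namespace and a pure
PyTorch reference implementation rather than the real CUDA kernel.

    import torch

    # Register a schema in which the second argument is declared as mutated.
    lib = torch.library.Library("demo_moe", "DEF")
    lib.define("moe_sum(Tensor input, Tensor! output) -> ()")


    def moe_sum_impl(input: torch.Tensor, output: torch.Tensor) -> None:
        # Reference behaviour: sum over the expert (topk) dimension and write
        # the result into the caller-provided output buffer.
        torch.sum(input, dim=1, out=output)


    lib.impl("moe_sum", moe_sum_impl, "CompositeExplicitAutograd")

    if __name__ == "__main__":
        x = torch.randn(4, 2, 8)
        out = torch.empty(4, 8)
        torch.ops.demo_moe.moe_sum(x, out)
        assert torch.allclose(out, x.sum(dim=1))

Declaring the mutated argument correctly matters because schema-based tooling (e.g.
torch.library.opcheck and the functionalization passes used under torch.compile) relies
on the annotation to know which tensors an op may alias or overwrite.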