[Bug] Fix moe_sum signature (#18440)

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm 2025-05-21 01:37:08 -04:00 committed by GitHub
parent 0c15c2e486
commit 92247c522e
2 changed files with 19 additions and 1 deletion


@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   // Calculate the result of moe by summing up the partial results
   // from all selected experts.
-  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.def("moe_sum(Tensor input, Tensor! output) -> ()");
   m.impl("moe_sum", torch::kCUDA, &moe_sum);
   // Aligning the number of tokens to be processed by each expert such

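The corrected schema marks output, not input, as the tensor the kernel writes to. As a rough point of reference, the sketch below shows the equivalent computation in plain PyTorch (an illustration only, not the CUDA kernel; moe_sum_reference is a hypothetical name, and the (num_tokens, topk, hidden_size) layout is taken from the test added below):

import torch

def moe_sum_reference(input: torch.Tensor, output: torch.Tensor) -> None:
    # Sum the per-expert partial results over the topk dimension and
    # write them into `output` in place, which is why the schema puts
    # the mutation marker (Tensor!) on `output` rather than `input`.
    # input:  (num_tokens, topk, hidden_size)
    # output: (num_tokens, hidden_size)
    torch.sum(input, dim=1, out=output)
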

@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
     opcheck(torch.ops._moe_C.moe_align_block_size,
             (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
              num_tokens_post_pad))
+
+
+@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
+    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
+    actual = torch.empty((m, k), device="cuda", dtype=dtype)
+
+    expected = input.sum(dim=1)
+    torch.ops._moe_C.moe_sum(input, actual)
+
+    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
+
+    opcheck(torch.ops._moe_C.moe_sum, (input, actual))
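The opcheck call is what the old schema tripped over: torch.library.opcheck (which the opcheck helper above drives) validates a custom op against its registered schema, including which arguments the implementation actually mutates, so annotating input with Tensor! while the kernel writes to output is the kind of mismatch it flags. The standalone sketch below shows the same annotation pattern through the Python torch.library API rather than vLLM's C++ bindings (the demo namespace and my_sum op are made up for illustration):

import torch

# Register a schema whose mutation annotation (Tensor!) sits on the
# argument that the implementation really writes to.
lib = torch.library.Library("demo", "DEF")
lib.define("my_sum(Tensor input, Tensor! output) -> ()")

def my_sum_impl(input: torch.Tensor, output: torch.Tensor) -> None:
    # Mutates only `output`, matching the Tensor! annotation above.
    torch.sum(input, dim=1, out=output)

lib.impl("my_sum", my_sum_impl, "CompositeExplicitAutograd")

x = torch.randn(4, 2, 8)
out = torch.empty(4, 8)
torch.ops.demo.my_sum(x, out)
torch.testing.assert_close(out, x.sum(dim=1))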