[Bug] Fix moe_sum signature (#18440)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent 0c15c2e486
commit 92247c522e
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   // Calculate the result of moe by summing up the partial results
   // from all selected experts.
-  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.def("moe_sum(Tensor input, Tensor! output) -> ()");
   m.impl("moe_sum", torch::kCUDA, &moe_sum);
 
   // Aligning the number of tokens to be processed by each expert such
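
The fix moves the Tensor! mutability annotation from input to output: moe_sum reads the per-expert partial results from input and writes the reduced sum into output in place, so output is the argument the schema must declare as mutated. As a rough illustration of the same registration pattern, here is a minimal sketch that uses a hypothetical demo_moe namespace and a CPU-only reference kernel; it is not vLLM's actual binding, which registers a CUDA kernel from C++.

import torch
from torch.library import Library

demo_lib = Library("demo_moe", "DEF")  # hypothetical namespace, not vLLM's _moe_C

# Tensor! marks the argument the kernel writes in place; after the fix that is
# output, while input stays read-only.
demo_lib.define("moe_sum(Tensor input, Tensor! output) -> ()")

def moe_sum_cpu(input: torch.Tensor, output: torch.Tensor) -> None:
    # Reference semantics: sum the partial expert results over the topk dim.
    output.copy_(input.sum(dim=1))

demo_lib.impl("moe_sum", moe_sum_cpu, "CPU")

x = torch.randn(4, 2, 8)
out = torch.empty(4, 8)
torch.ops.demo_moe.moe_sum(x, out)
assert torch.allclose(out, x.sum(dim=1))
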
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
     opcheck(torch.ops._moe_C.moe_align_block_size,
             (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
              num_tokens_post_pad))
+
+
+@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
+    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
+    actual = torch.empty((m, k), device="cuda", dtype=dtype)
+
+    expected = input.sum(dim=1)
+    torch.ops._moe_C.moe_sum(input, actual)
+
+    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
+
+    opcheck(torch.ops._moe_C.moe_sum, (input, actual))
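
The new test also runs opcheck on the op, which validates the registered schema against the kernel's observed behavior. Below is a hedged sketch of that kind of check using torch.library.opcheck directly (assumed available in PyTorch 2.4+), again with a hypothetical demo_moe2 namespace and a CPU kernel; the opcheck helper imported by vLLM's kernel tests is assumed to forward to the same API. With the old schema (Tensor! input, Tensor output), the schema test would be expected to flag the undeclared in-place write to output.

import torch
from torch.library import Library

lib = Library("demo_moe2", "DEF")  # hypothetical namespace for illustration
lib.define("moe_sum(Tensor input, Tensor! output) -> ()")  # corrected schema

def moe_sum_cpu(input: torch.Tensor, output: torch.Tensor) -> None:
    output.copy_(input.sum(dim=1))  # writes only output, the Tensor! argument

lib.impl("moe_sum", moe_sum_cpu, "CPU")

x = torch.randn(3, 2, 4)
out = torch.empty(3, 4)
# test_schema checks that the kernel mutates only arguments annotated with "!";
# a schema that marked input mutable and output read-only would be expected to
# fail here, since output is written in place.
torch.library.opcheck(torch.ops.demo_moe2.moe_sum, (x, out),
                      test_utils="test_schema")
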