From bbe888d03300e4e8f542e6e29b21d5001946a5b1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 21:55:13 +0000 Subject: [PATCH] Slice fp8 reference matmul output to num_tokens in batched MoE test ref_impl Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 8c143c808cf86..b1dda531e6b30 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -136,16 +136,18 @@ def ref_impl( for e in range(num_experts): num_tokens = num_expert_tokens_cpu[e] if A.dtype == torch.torch.float8_e4m3fn: - C[e, :, :] = native_w8a8_block_matmul(A[e, :, :], - B[e].transpose(0, 1), - A_scale, - B_scale, - [1,1])#block_shape) + tmp = native_w8a8_block_matmul(A[e, :, :], + B[e].transpose(0, 1), + A_scale, + B_scale, + [1,1])#block_shape) + C[e, :num_tokens, :] = tmp[:num_tokens, :] else: C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) return C + @pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512])