diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 8c143c808cf86..b1dda531e6b30 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -136,16 +136,18 @@ def ref_impl( for e in range(num_experts): num_tokens = num_expert_tokens_cpu[e] if A.dtype == torch.torch.float8_e4m3fn: - C[e, :, :] = native_w8a8_block_matmul(A[e, :, :], - B[e].transpose(0, 1), - A_scale, - B_scale, - [1,1])#block_shape) + tmp = native_w8a8_block_matmul(A[e, :, :], + B[e].transpose(0, 1), + A_scale, + B_scale, + [1,1])#block_shape) + C[e, :num_tokens, :] = tmp[:num_tokens, :] else: C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) return C + @pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512])