Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-21 21:55:13 +00:00
parent 25ed6738d4
commit bbe888d033

View File

@@ -136,16 +136,18 @@ def ref_impl(
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
if A.dtype == torch.torch.float8_e4m3fn:
C[e, :, :] = native_w8a8_block_matmul(A[e, :, :],
B[e].transpose(0, 1),
A_scale,
B_scale,
[1,1])#block_shape)
tmp = native_w8a8_block_matmul(A[e, :, :],
B[e].transpose(0, 1),
A_scale,
B_scale,
[1,1])#block_shape)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
else:
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])