From b4f17e12a444a90d21528c412d19f6d7488494ba Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 20 Jun 2025 19:47:25 +0000 Subject: [PATCH] tolerances Signed-off-by: Tyler Michael Smith --- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5cfb4266ff933..673a0aa367948 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -51,14 +51,13 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output - ref_s = torch.empty((E, T, H // group_size), - dtype=torch.float32, - device="cuda") - ref_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device="cuda") # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() + ref_s = torch.empty((T, H // group_size), + dtype=torch.float32, + device="cuda") + ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") for t in range(nt): data = merged[e, t] data_grp = data.view(H // group_size, group_size) @@ -69,14 +68,16 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): clamped = scaled.clamp(fp8_min, fp8_max) q = clamped.to(torch.float8_e4m3fn) - ref_s[e, t] = scale - ref_q[e, t] = q + ref_s[t] = scale + ref_q[t] = q - # Compare scales and quantized outputs for valid tokens only - for e in range(E): - nt = tokens_per_expert[e].item() - torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt]) + y_se = y_s[e] + y_qe = y_q[e] + + torch.testing.assert_close(y_se[:nt], ref_s[:nt]) torch.testing.assert_close( - y_q[e, :nt].to(torch.float32), - ref_q[e, :nt].to(torch.float32), + y_qe[:nt].to(torch.float32), + ref_q[:nt].to(torch.float32), + atol=2, + rtol=2e-1, )