tolerances

Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
This commit is contained in:
Tyler Michael Smith 2025-06-20 19:47:25 +00:00
parent 21ffc7353a
commit b4f17e12a4

View File

@ -51,14 +51,13 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
silu_x = y1 * torch.sigmoid(y1)
merged = silu_x * y2
# Compute reference scales and quantized output
ref_s = torch.empty((E, T, H // group_size),
dtype=torch.float32,
device="cuda")
ref_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device="cuda")
# Compute reference scales and quantized output, skipping padded tokens
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty((T, H // group_size),
dtype=torch.float32,
device="cuda")
ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
for t in range(nt):
data = merged[e, t]
data_grp = data.view(H // group_size, group_size)
@ -69,14 +68,16 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
clamped = scaled.clamp(fp8_min, fp8_max)
q = clamped.to(torch.float8_e4m3fn)
ref_s[e, t] = scale
ref_q[e, t] = q
ref_s[t] = scale
ref_q[t] = q
# Compare scales and quantized outputs for valid tokens only
for e in range(E):
nt = tokens_per_expert[e].item()
torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt])
y_se = y_s[e]
y_qe = y_q[e]
torch.testing.assert_close(y_se[:nt], ref_s[:nt])
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_q[e, :nt].to(torch.float32),
y_qe[:nt].to(torch.float32),
ref_q[:nt].to(torch.float32),
atol=2,
rtol=2e-1,
)