mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 02:57:05 +08:00
tolerances
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
This commit is contained in:
parent
21ffc7353a
commit
b4f17e12a4
@ -51,14 +51,13 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
|
||||
silu_x = y1 * torch.sigmoid(y1)
|
||||
merged = silu_x * y2
|
||||
|
||||
# Compute reference scales and quantized output
|
||||
ref_s = torch.empty((E, T, H // group_size),
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
ref_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device="cuda")
|
||||
# Compute reference scales and quantized output, skipping padded tokens
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
ref_s = torch.empty((T, H // group_size),
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
|
||||
for t in range(nt):
|
||||
data = merged[e, t]
|
||||
data_grp = data.view(H // group_size, group_size)
|
||||
@ -69,14 +68,16 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
|
||||
clamped = scaled.clamp(fp8_min, fp8_max)
|
||||
q = clamped.to(torch.float8_e4m3fn)
|
||||
|
||||
ref_s[e, t] = scale
|
||||
ref_q[e, t] = q
|
||||
ref_s[t] = scale
|
||||
ref_q[t] = q
|
||||
|
||||
# Compare scales and quantized outputs for valid tokens only
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt])
|
||||
y_se = y_s[e]
|
||||
y_qe = y_q[e]
|
||||
|
||||
torch.testing.assert_close(y_se[:nt], ref_s[:nt])
|
||||
torch.testing.assert_close(
|
||||
y_q[e, :nt].to(torch.float32),
|
||||
ref_q[e, :nt].to(torch.float32),
|
||||
y_qe[:nt].to(torch.float32),
|
||||
ref_q[:nt].to(torch.float32),
|
||||
atol=2,
|
||||
rtol=2e-1,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user