diff --git a/inference/kernel.py b/inference/kernel.py index af69daa..e12ad48 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -23,7 +23,9 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448. + amax = tl.max(tl.abs(x)) + amax = tl.min(amax, 1e-4) + s = amax / 448. if scale_fmt == "ue8m0": exp = tl.math.ceil(tl.math.log2(s)) s = tl.math.exp2(exp)