From adecc0efbe2fda18945734168fce6e0df0d804c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 27 Aug 2025 17:12:13 +0800 Subject: [PATCH] fix rmsnorm and act_quant_kernel --- inference/kernel.py | 3 ++- inference/model.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/inference/kernel.py b/inference/kernel.py index 57832ad..22afc92 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -23,7 +23,8 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - amax = tl.max(tl.abs(x), 1e-4) + amax = tl.max(tl.abs(x)) # reduction + amax = tl.maximum(amax, 1e-4) # clamp to 1e-4 s = amax / 448. if scale_fmt == "ue8m0": exp = tl.math.ceil(tl.math.log2(s)) diff --git a/inference/model.py b/inference/model.py index 8868499..9ce8168 100644 --- a/inference/model.py +++ b/inference/model.py @@ -291,7 +291,12 @@ class RMSNorm(nn.Module): Returns: torch.Tensor: Normalized tensor with the same shape as input. """ - return F.rms_norm(x, (self.dim,), self.weight, self.eps) + dtype = x.dtype + # make sure rms norm is computed in fp32 + x = x.to(torch.float32) + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype) def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: