fix rmsnorm and act_quant_kernel

youkaichao 2025-08-27 17:12:13 +08:00
parent 82f6008c8c
commit adecc0efbe
2 changed files with 8 additions and 2 deletions


@@ -23,7 +23,8 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    amax = tl.max(tl.abs(x), 1e-4)
+    amax = tl.max(tl.abs(x))  # reduction
+    amax = tl.maximum(amax, 1e-4)  # clamp to 1e-4
     s = amax / 448.
     if scale_fmt == "ue8m0":
         exp = tl.math.ceil(tl.math.log2(s))
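
Note on this hunk: in Triton, the second positional argument of tl.max is the reduction axis, so the old tl.max(tl.abs(x), 1e-4) did not act as a lower bound on the result; the fix performs the reduction first and then clamps with tl.maximum. A minimal PyTorch sketch of the per-block scale computation the fixed kernel implements (the function name and the block_size default are illustrative, not part of the diff; 448. is the FP8 e4m3 maximum already used by the kernel):

    import torch

    def act_quant_reference(x: torch.Tensor, block_size: int = 128):
        # One row per Triton program, mirroring the BLOCK_SIZE tiling (128 is an assumed default).
        x = x.float().view(-1, block_size)
        amax = x.abs().amax(dim=-1, keepdim=True)  # reduction, like tl.max(tl.abs(x))
        amax = amax.clamp(min=1e-4)                # clamp, like tl.maximum(amax, 1e-4)
        s = amax / 448.                            # per-block scale, 448 = FP8 e4m3 max
        y = x / s                                  # values the real kernel casts to FP8
        return y, s.squeeze(-1)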


@@ -291,7 +291,12 @@ class RMSNorm(nn.Module):
         Returns:
             torch.Tensor: Normalized tensor with the same shape as input.
         """
-        return F.rms_norm(x, (self.dim,), self.weight, self.eps)
+        dtype = x.dtype
+        # make sure rms norm is computed in fp32
+        x = x.to(torch.float32)
+        var = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(var + self.eps)
+        return (self.weight * x).to(dtype)


 def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
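
Note on this hunk: the new forward upcasts the activation to fp32 before the mean-of-squares reduction and casts back to the original dtype at the end, instead of calling F.rms_norm on the (possibly bf16) input. A self-contained sketch of the fixed module, with a constructor reconstructed from how forward uses self.weight and self.eps (the dim/eps signature and the eps default are assumptions, not part of the diff), plus a quick check that it matches F.rms_norm once everything is already fp32:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class RMSNormFP32(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.dim = dim
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            dtype = x.dtype
            x = x.to(torch.float32)                 # compute the norm in fp32
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype)      # cast back to the input dtype

    x = torch.randn(2, 16)
    m = RMSNormFP32(16)
    assert torch.allclose(m(x), F.rms_norm(x, (16,), m.weight, m.eps), atol=1e-6)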