Compare commits

..

2 Commits

Author SHA1 Message Date
Nripesh Niketan
e867d890db
Merge a5336884cfa952634f355549643dcb781eaec872 into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b 2025-09-01 00:53:48 +01:00
GeeeekExplorer
9b4e9788e4 Merge pull request #969 from youkaichao/rmsnorm
act_quant_kernel
2025-08-28 11:24:26 +08:00

View File

@ -291,12 +291,7 @@ class RMSNorm(nn.Module):
Returns:
torch.Tensor: Normalized tensor with the same shape as input.
"""
dtype = x.dtype
# make sure rms norm is computed in fp32
x = x.to(torch.float32)
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype)
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: