mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-09 21:04:36 +08:00
Compare commits
2 Commits
cda950e6c2
...
e867d890db
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e867d890db | ||
|
|
9b4e9788e4 |
@ -291,12 +291,7 @@ class RMSNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Normalized tensor with the same shape as input.
|
torch.Tensor: Normalized tensor with the same shape as input.
|
||||||
"""
|
"""
|
||||||
dtype = x.dtype
|
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
|
||||||
# make sure rms norm is computed in fp32
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
var = x.pow(2).mean(-1, keepdim=True)
|
|
||||||
x = x * torch.rsqrt(var + self.eps)
|
|
||||||
return (self.weight * x).to(dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user