mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-09 04:44:28 +08:00
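Fix RMSNorm (compute the normalization in fp32, then cast back to the input dtype) and act_quant_kernel (reduce to amax first, then clamp it to 1e-4)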
This commit is contained in:
parent
82f6008c8c
commit
adecc0efbe
@ -23,7 +23,8 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t
|
|||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
x = tl.load(x_ptr + offs).to(tl.float32)
|
x = tl.load(x_ptr + offs).to(tl.float32)
|
||||||
amax = tl.max(tl.abs(x), 1e-4)
|
amax = tl.max(tl.abs(x)) # reduction
|
||||||
|
amax = tl.maximum(amax, 1e-4) # clamp to 1e-4
|
||||||
s = amax / 448.
|
s = amax / 448.
|
||||||
if scale_fmt == "ue8m0":
|
if scale_fmt == "ue8m0":
|
||||||
exp = tl.math.ceil(tl.math.log2(s))
|
exp = tl.math.ceil(tl.math.log2(s))
|
||||||
|
|||||||
@ -291,7 +291,12 @@ class RMSNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Normalized tensor with the same shape as input.
|
torch.Tensor: Normalized tensor with the same shape as input.
|
||||||
"""
|
"""
|
||||||
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
|
dtype = x.dtype
|
||||||
|
# make sure rms norm is computed in fp32
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
var = x.pow(2).mean(-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(var + self.eps)
|
||||||
|
return (self.weight * x).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user