mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-16 00:04:31 +08:00
Update model.py
This commit is contained in:
parent
c21638c56c
commit
b2253d1807
@@ -290,106 +290,6 @@ class RMSNorm(nn.Module):
        """
        return F.rms_norm(x, (self.dim,), self.weight, self.eps)
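        # Annotation (not part of the diff): F.rms_norm here is
        # torch.nn.functional.rms_norm, which normalizes over the trailing
        # normalized_shape dimensions; up to numerics it behaves like:
        #   x * self.weight / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)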
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor
    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
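    # Annotation (not part of the diff): this inverts the rotation count.
    # Dimension pair i of the RoPE spectrum advances by theta_i = base**(-2*i/dim)
    # radians per position, so it completes max_seq_len * theta_i / (2*pi) full
    # rotations over the original context; solving that for i gives the formula
    # above. With values assumed purely for illustration (dim=64, base=10000.0,
    # max_seq_len=4096), find_correction_dim(32, ...) is roughly 10.47: pairs
    # below dimension ~10 rotate more than 32 times and are left uncorrected.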
    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)
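    # Annotation (not part of the diff): continuing the illustrative values
    # above (dim=64, base=10000.0, max_seq_len=4096) with YaRN-style bounds
    # low_rot=32 and high_rot=1, this returns roughly
    # (floor(10.47), ceil(22.51)) == (10, 23): the band of dimension pairs
    # whose frequencies are smoothly blended rather than fully rescaled.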
    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func
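    # Annotation (not part of the diff): a small worked example, assuming
    # nothing beyond the function itself. linear_ramp_factor(2, 5, 8) yields
    # tensor([0.0000, 0.0000, 0.0000, 0.3333, 0.6667, 1.0000, 1.0000, 1.0000]):
    # 0 below min, 1 above max, and a linear transition in between.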
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
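# Illustrative shape check (not part of the diff; the field values below are
# assumptions for the sketch, stood in via a namespace rather than the real
# ModelArgs):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(qk_rope_head_dim=64, max_seq_len=4096,
#                          original_seq_len=4096, rope_theta=10000.0,
#                          rope_factor=40, beta_fast=32, beta_slow=1)
#   freqs_cis = precompute_freqs_cis(args)
#   freqs_cis.shape   # torch.Size([4096, 32]), dtype torch.complex64
#
# One row per position, one complex unit phasor per (even, odd) channel pair.
# The seqlen > original_seq_len branch fires only when extending the context;
# in that case the blend keeps fast-rotating pairs (below the correction
# range) intact and divides slow-rotating ones by rope_factor.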
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)
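# Usage sketch (annotation, not part of the diff; the concrete shapes are
# assumptions): the view(1, x.size(1), 1, ...) broadcast implies x is laid out
# as (batch, seqlen, n_heads, rope_head_dim) with an even last dimension, and
# that freqs_cis is already sliced to the current positions:
#
#   q = torch.randn(2, 16, 8, 64)             # batch=2, seqlen=16, heads=8
#   out = apply_rotary_emb(q, freqs_cis[:16])
#   out.shape                                 # torch.Size([2, 16, 8, 64])
#
# Each adjacent (even, odd) channel pair is reinterpreted as one complex
# number and multiplied by the matching unit phasor, the standard
# complex-multiplication form of RoPE; the result is cast back to x's dtype.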
class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.