Mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
Update model.py
parent c21638c56c
commit b2253d1807
@@ -290,106 +290,6 @@ class RMSNorm(nn.Module):
        """
        return F.rms_norm(x, (self.dim,), self.weight, self.eps)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.
(remainder of the diff not rendered in this view)
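For orientation only, and not part of this commit: a minimal sketch of how precompute_freqs_cis and apply_rotary_emb would be used together, assuming ModelArgs, precompute_freqs_cis, and apply_rotary_emb are imported from model.py. The field values passed to ModelArgs below are hypothetical placeholders, and any fields not set are assumed to have defaults.

    import torch

    # Hypothetical argument values; ModelArgs is assumed to supply defaults for the remaining fields.
    args = ModelArgs(max_seq_len=4096, qk_rope_head_dim=64)

    # Complex rotation table: shape (max_seq_len, qk_rope_head_dim // 2), dtype complex64.
    freqs_cis = precompute_freqs_cis(args)

    # Rotate the RoPE half of a dummy query tensor; the output keeps the input's shape and dtype.
    bsz, seqlen, n_heads = 2, 16, 8
    q_pe = torch.randn(bsz, seqlen, n_heads, args.qk_rope_head_dim)
    q_pe = apply_rotary_emb(q_pe, freqs_cis[:seqlen])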