diff --git a/inference/model.py b/inference/model.py
index c143e97..9b0aa40 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -290,106 +290,6 @@ class RMSNorm(nn.Module):
         """
         return F.rms_norm(x, (self.dim,), self.weight, self.eps)
 
-
-def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
-    """
-    Precomputes frequency-based complex exponential values for rotary positional embeddings.
-
-    Args:
-        args (ModelArgs): Model arguments containing positional embedding parameters.
-
-    Returns:
-        torch.Tensor: Precomputed complex exponential values for positional embeddings.
-    """
-    dim = args.qk_rope_head_dim
-    seqlen = args.max_seq_len
-    beta_fast = args.beta_fast
-    beta_slow = args.beta_slow
-    base = args.rope_theta
-    factor = args.rope_factor
-
-    def find_correction_dim(num_rotations, dim, base, max_seq_len):
-        """
-        Computes the correction dimension for a given number of rotations in the rotary positional embedding.
-
-        Args:
-            num_rotations (float): Number of rotations to compute the correction for.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            float: The correction dimension based on the input parameters.
-        """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
-
-    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
-        """
-        Computes the range of correction dimensions for rotary positional embeddings.
-
-        Args:
-            low_rot (float): Lower bound for the number of rotations.
-            high_rot (float): Upper bound for the number of rotations.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
-        """
-        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
-        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
-
-    def linear_ramp_factor(min, max, dim):
-        """
-        Computes a linear ramp function used to smooth values between a minimum and maximum range.
-
-        Args:
-            min (float): Minimum value for the ramp function.
-            max (float): Maximum value for the ramp function.
-            dim (int): Dimensionality of the ramp tensor.
-
-        Returns:
-            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
-                clamped to the range [0, 1].
-        """
-        if min == max:
-            max += 0.001
-        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-        ramp_func = torch.clamp(linear_func, 0, 1)
-        return ramp_func
-
-    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
-    if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
-        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
-        freqs = freqs / factor * (1 - smooth) + freqs * smooth
-
-    t = torch.arange(seqlen)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs_cis
-
-
-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-    """
-    Applies rotary positional embeddings to the input tensor.
-
-    Args:
-        x (torch.Tensor): Input tensor with positional embeddings to be applied.
-        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
-
-    Returns:
-        torch.Tensor: Tensor with rotary embeddings applied.
-    """
-    dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
-    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
-    y = torch.view_as_real(x * freqs_cis).flatten(3)
-    return y.to(dtype)
-
-
 class MLA(nn.Module):
     """
     Multi-Head Latent Attention (MLA) Layer.
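Reviewer note: the two helpers removed above are the YaRN-scaled rotary-embedding (RoPE) pair, so any remaining call sites of `precompute_freqs_cis` / `apply_rotary_emb` need to resolve against wherever these definitions now live. For reference, below is a minimal standalone sketch of how the pair fit together. It is not part of the patch: `Args` is a hypothetical stand-in for the repo's `ModelArgs`, and the precompute step keeps only the base-frequency path, omitting the YaRN rescaling branch that activates when `max_seq_len > original_seq_len`.

```python
from dataclasses import dataclass

import torch


@dataclass
class Args:
    # Hypothetical stand-in for ModelArgs; field names mirror what the
    # removed code reads.
    qk_rope_head_dim: int = 64
    max_seq_len: int = 4096
    rope_theta: float = 10000.0


def precompute_freqs_cis(args: Args) -> torch.Tensor:
    """Base RoPE frequencies as complex exponentials, shape (seqlen, dim // 2)."""
    dim = args.qk_rope_head_dim
    # Per-pair angular frequencies: theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (args.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(args.max_seq_len)
    # Outer product gives the rotation angle at every (position, pair).
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of x (batch, seqlen, n_heads, head_dim) as complex numbers."""
    dtype = x.dtype
    # Reinterpret adjacent channel pairs as complex numbers, rotate, flatten back.
    xc = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, xc.size(1), 1, xc.size(-1))
    return torch.view_as_real(xc * freqs_cis).flatten(3).to(dtype)


args = Args()
freqs_cis = precompute_freqs_cis(args)            # (4096, 32)
q = torch.randn(2, 16, 8, args.qk_rope_head_dim)  # (batch, seqlen, heads, head_dim)
out = apply_rotary_emb(q, freqs_cis[:16])         # slice the table to the current seqlen
print(out.shape)  # torch.Size([2, 16, 8, 64])
```

The apply step mirrors the removed implementation; only the precompute is simplified for illustration.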