mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 09:45:58 +08:00
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import math
|
|
|
|
import torch
|
|
|
|
from .base import RotaryEmbedding
|
|
|
|
|
|
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled per wavelength.

    The inverse frequencies produced by the parent class are split into three
    bands by comparing each component's wavelength (``2 * pi / inv_freq``)
    against two cut-offs derived from ``orig_max_position``:

    * wavelength below ``orig_max_position / high_freq_factor``:
      the inverse frequency is kept as is;
    * wavelength above ``orig_max_position / low_freq_factor``:
      the inverse frequency is divided by ``scaling_factor``;
    * anything else: a linear blend of the scaled and unscaled values.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        """Store the scaling parameters, then run the parent constructor.

        Args:
            head_size: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            rotary_dim: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            max_position_embeddings: Forwarded unchanged to
                ``RotaryEmbedding.__init__``.
            base: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            is_neox_style: Forwarded unchanged to
                ``RotaryEmbedding.__init__``.
            dtype: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            scaling_factor: Divisor applied to inverse frequencies in the
                long-wavelength band.
            low_freq_factor: Divides ``orig_max_position`` to give the
                long-wavelength cut-off.
            high_freq_factor: Divides ``orig_max_position`` to give the
                short-wavelength cut-off.
            orig_max_position: Reference length from which both wavelength
                cut-offs are derived.
        """
        # These attributes are read by _compute_inv_freq, so they are set
        # before delegating to the parent.
        # NOTE(review): presumably RotaryEmbedding.__init__ calls
        # _compute_inv_freq during construction - confirm against base.py
        # before reordering these statements.
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the parent's inverse frequencies after band-wise rescaling.

        Args:
            base: Passed straight through to the parent implementation.

        Returns:
            A tensor holding, element by element, either the unmodified
            inverse frequency, the inverse frequency divided by
            ``scaling_factor``, or a linear mix of the two.
        """
        base_inv_freq = super()._compute_inv_freq(base)

        # Wavelength thresholds separating the three bands.
        long_cutoff = self.orig_max_position / self.low_freq_factor
        short_cutoff = self.orig_max_position / self.high_freq_factor

        wavelengths = 2 * math.pi / base_inv_freq

        if self.low_freq_factor == self.high_freq_factor:
            # Equal factors would make the blend denominator zero; use a
            # constant weight instead.
            blend = 0
        else:
            factor_span = self.high_freq_factor - self.low_freq_factor
            blend = (
                self.orig_max_position / wavelengths - self.low_freq_factor
            ) / factor_span

        # Candidate values for everything that is not in the short band.
        fully_scaled = base_inv_freq / self.scaling_factor
        interpolated = (
            (1 - blend) * base_inv_freq / self.scaling_factor
            + blend * base_inv_freq
        )
        long_or_middle = torch.where(
            wavelengths > long_cutoff, fully_scaled, interpolated
        )

        # Short wavelengths keep the parent's value untouched.
        return torch.where(
            wavelengths < short_cutoff, base_inv_freq, long_or_middle
        )
|