mirror of https://git.datalinker.icu/vllm-project/vllm.git
Support YaRN models (#1264)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Viktor Ferenczi <viktor@ferenczi.eu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 555bdcc5a3
commit 9f669a9a7c
@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
   scalar_t* __restrict__ out,          // [..., d]
   const scalar_t* __restrict__ input,  // [..., 2, d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
     const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
   torch::Tensor& out,    // [..., d]
   torch::Tensor& input)  // [..., 2 * d]
 {
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;
 
   dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
   scalar_t* __restrict__ out,          // [..., d]
   const scalar_t* __restrict__ input,  // [..., d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
   int d = input.size(-1); \
-  int num_tokens = input.numel() / d; \
+  int64_t num_tokens = input.numel() / d; \
   dim3 grid(num_tokens); \
   dim3 block(std::min(d, 1024)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
@@ -84,7 +84,7 @@ void rotary_embedding(
   int head_size,
   torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int64_t num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
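Side note (not part of the diff): the int -> int64_t change in the kernel hunks above matters because flattened offsets such as token_idx * 2 * d are computed in the index type, and with 32-bit ints they can overflow once num_tokens * d approaches 2^31. A back-of-the-envelope check in plain Python, with purely hypothetical sizes:

# Illustrative only: when would a 32-bit flattened index overflow?
INT32_MAX = 2**31 - 1
num_tokens = 300_000  # hypothetical: many sequences x long contexts
d = 11_008            # hypothetical: LLaMA-style MLP intermediate size
flat_index = (num_tokens - 1) * 2 * d + (d - 1)
print(flat_index > INT32_MAX)  # True -> 64-bit indexing is required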
@@ -390,6 +390,9 @@ def _get_and_verify_max_len(
     if rope_scaling is not None:
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
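Side note (not part of the diff): for YaRN configs the derived maximum length starts from rope_scaling["original_max_position_embeddings"] rather than the value read from the base model config, and is then multiplied by the scaling factor like the other scaling types. A minimal sketch with hypothetical config values:

# Hypothetical rope_scaling entry, not taken from a real model config.
rope_scaling = {
    "type": "yarn",
    "factor": 16.0,
    "original_max_position_embeddings": 4096,
}
derived_max_model_len = 8192  # whatever was derived from the HF config keys
if rope_scaling["type"] == "yarn":
    derived_max_model_len = rope_scaling["original_max_position_embeddings"]
derived_max_model_len *= rope_scaling["factor"]
print(derived_max_model_len)  # 65536.0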
@@ -12,7 +12,7 @@ from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.rotary_embedding import (
     DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding)
+    RotaryEmbedding, YaRNScalingRotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -334,6 +334,19 @@ class PagedAttentionWithRoPE(PagedAttention):
             self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
                 scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor",
+                         "beta_fast", "beta_slow")
+            }
+            self.rotary_emb = YaRNScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
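Side note (not part of the diff): the dict comprehension above forwards only the optional YaRN tuning knobs that happen to be present in rope_scaling, so missing keys fall back to the defaults of YaRNScalingRotaryEmbedding. A small sketch of the same filtering with a hypothetical config dict:

# Hypothetical rope_scaling dict; only YaRN-specific knobs are forwarded.
rope_scaling = {
    "type": "yarn",
    "factor": 16.0,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
}
extra_kwargs = {
    k: v
    for k, v in rope_scaling.items()
    if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
}
print(extra_kwargs)  # {'beta_fast': 32, 'beta_slow': 1}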
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rotary Positional Embeddings."""
+import math
 from typing import Tuple, Union
 
 import torch
@@ -167,3 +168,106 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+# Inverse dim formula to find dim based on number of rotations
+def _yarn_find_correction_dim(num_rotations: int,
+                              dim: int,
+                              base: float = 10000,
+                              max_position_embeddings: int = 2048) -> float:
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Find dim range bounds based on rotations
+def _yarn_find_correction_range(low_rot: int,
+                                high_rot: int,
+                                dim: int,
+                                base: float = 10000,
+                                max_position_embeddings: int = 2048) -> int:
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: torch.dtype,
+                           device: torch.device) -> torch.Tensor:
+    if low == high:
+        high += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
+                   low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def _yarn_get_mscale(scale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * math.log(scale) + 1.0
+
+
+class YaRNScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: float = 32,
+        beta_slow: float = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2, dtype=torch.float,
+            device="cuda")) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
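Side note (not part of the diff): the helper formulas above can be sanity-checked with plain math. With the default beta_fast=32 and beta_slow=1 on a 128-dimensional rotary embedding, base 10000 and 2048 original positions, the correction range comes out to roughly dims 16..41, and the magnitude scale for a 16x factor is about 1.28. An illustrative standalone check, using assumed example values only:

import math

def find_correction_dim(num_rotations, dim, base=10000, max_pos=2048):
    # Same inverse-dim formula as _yarn_find_correction_dim above.
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base))

dim, base, max_pos = 128, 10000, 2048
low = math.floor(find_correction_dim(32, dim, base, max_pos))   # beta_fast
high = math.ceil(find_correction_dim(1, dim, base, max_pos))    # beta_slow
print(low, high)                   # ramp boundaries, roughly 16 and 41 here
print(0.1 * math.log(16.0) + 1.0)  # mscale for scaling_factor=16 -> ~1.277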