[FEAT] Refactor ROPE into module (#22192)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian 2025-08-04 19:12:16 -07:00 committed by GitHub
parent f4f4e7ef27
commit 6ad6b8e115
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 2111 additions and 1967 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings."""
from typing import Any, Optional
import torch
from .base import RotaryEmbedding
from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool = True,
rope_scaling: Optional[dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dual_chunk_attention_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
**extra_kwargs)
elif not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype,
mixed_b)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_alpha, dtype)
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
else:
raise ValueError("Dynamic rope scaling must contain either "
"'alpha' or 'factor' field")
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow")
}
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb

View File

@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings Base Class."""
from typing import Optional
import torch
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm import _custom_ops as ops
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if key is None:
# XPU kernel doesn't support key=None so fall back to native impl
# TODO(sarckk): add support for optional key in
# ipex.llm.functional.rotary_embedding_batched
return self.forward_native(positions, query, key, offsets)
else:
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim, offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
# x1 = x[..., ::2]
# x2 = x[..., 1::2]
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
if offsets is not None:
positions = positions + offsets
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if self.rotary_dim == self.head_size:
query = apply_rotary_emb_dispatch(query, cos, sin,
self.is_neox_style)
query = query.reshape(query_shape)
if key is not None:
key = apply_rotary_emb_dispatch(key, cos, sin,
self.is_neox_style)
key = key.reshape(key_shape)
else:
head_size = query.shape[-1]
query_reshaped = query.view(-1, head_size)
query_pass = query_reshaped[:, self.rotary_dim:].view(
*query.shape[:-1], head_size - self.rotary_dim)
query_rot = query_reshaped[:, :self.rotary_dim].view(
*query.shape[:-1], self.rotary_dim)
query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass),
dim=-1).reshape(query_shape)
if key is not None:
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s

View File

@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
from vllm.platforms import current_platform
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
# common functions
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> float:
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
# Find dim range bounds based on rotations
def yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> tuple[int, int]:
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype) -> torch.Tensor:
if low == high:
high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0

View File

@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
from vllm.platforms import current_platform
from .base import RotaryEmbedding
from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
yarn_linear_ramp_mask)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(
torch.arange(0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key

View File

@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.custom_op import CustomOp
from .common import rotate_gptj, rotate_neox
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
chunk_size: int,
local_size: int,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
q_inter_cache) = self._compute_cos_sin_cache()
self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
self.register_buffer("cos_sin_qc_no_clamp_cache",
qc_no_clamp_cache,
persistent=False)
self.register_buffer("cos_sin_q_inter_cache",
q_inter_cache,
persistent=False)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
chunk_len = self.chunk_size - self.local_size
q_t = torch.arange(chunk_len, dtype=torch.float)
qc_t = (torch.arange(chunk_len, dtype=torch.float) +
chunk_len).clamp(max=self.chunk_size)
k_t = torch.arange(self.max_position_embeddings,
dtype=torch.float) % chunk_len
# count from chunk_len, no clamp(self.chunk_size) restriction
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# count from self.chunk_size for q_inter's rope
q_inter_t = torch.arange(chunk_len,
dtype=torch.float) + self.chunk_size
q_freqs = torch.outer(q_t, inv_freq)
qc_freqs = torch.outer(qc_t, inv_freq)
k_freqs = torch.outer(k_t, inv_freq)
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
q_cos = q_freqs.cos()
q_sin = q_freqs.sin()
qc_cos = qc_freqs.cos()
qc_sin = qc_freqs.sin()
k_cos = k_freqs.cos()
k_sin = k_freqs.sin()
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
q_inter_cos = q_inter_freqs.cos()
q_inter_sin = q_inter_freqs.sin()
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
else:
query_pass = None
key_pass = None
positions_with_offsets = (torch.add(positions, offsets)
if offsets is not None else positions)
key = self._apply_rotary_embedding(
self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
chunk_len = self.chunk_size - self.local_size
query = self._apply_rotary_embedding(
self.cos_sin_q_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_succ = self._apply_rotary_embedding(
self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter = self._apply_rotary_embedding(
self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
query_rot, query_pass)
query_succ_critical = self._apply_rotary_embedding(
self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter_critical = self._apply_rotary_embedding(
self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
# merge query into one tensor to simplify the interfaces
query = torch.cat((
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
),
dim=-1)
return query, key
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
if self.rotary_dim < self.head_size:
hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
else:
hidden = hidden_rot
return hidden.flatten(-2).squeeze(0)
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
return s

View File

@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .base import RotaryEmbedding
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK alpha.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# For Hunyuan DynamicNTKAlphaRotaryEmbedding
max_len = self.max_position_embeddings
base = self.base * self.scaling_alpha**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache

View File

@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .base import RotaryEmbedding
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
base = self.base * (
(self.scaling_factor * max_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache

View File

@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .base import RotaryEmbedding
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
instead of running rotary embedding kernel per lora, we can run multiple
lora in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
Exemplary for two scaling factors x=1, y and z with embeddings
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
we construct the cos/sin cache as follows:
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
...
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
We then use offsets to index into the cos/sin cache for
the respective scaling factors.
The offset to cache can be accessed via `scaling_factor_to_offset` API.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factors: Union[list[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: list[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: list[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
@property
def scaling_factor_to_offset(self) -> dict[float, int]:
return self._scaling_factor_to_offset

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
from .base import RotaryEmbedding
class Llama3RotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
low_freq_wavelen = self.orig_max_position / self.low_freq_factor
high_freq_wavelen = self.orig_max_position / self.high_freq_factor
wave_len = 2 * math.pi / inv_freqs
if self.low_freq_factor != self.high_freq_factor:
smooth = (self.orig_max_position / wave_len - self.low_freq_factor
) / (self.high_freq_factor - self.low_freq_factor)
else:
smooth = 0
new_freqs = torch.where(
wave_len < high_freq_wavelen,
inv_freqs,
torch.where(
wave_len > low_freq_wavelen,
inv_freqs / self.scaling_factor,
(1 - smooth) * inv_freqs / self.scaling_factor +
smooth * inv_freqs,
),
)
return new_freqs

View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
from .base import RotaryEmbedding
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
):
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
return inv_freqs
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches = self.max_position_embeddings
img_idx = torch.arange(num_patches,
dtype=torch.int32) \
.reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
num_patches_single_dim = int(math.sqrt(num_patches))
frequencies_x = img_idx % num_patches_single_dim
frequencies_y = img_idx // num_patches_single_dim
freqs_x = ((frequencies_x + 1)[..., None] *
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
freqs_y = ((frequencies_y + 1)[..., None] *
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
freqs = torch.cat([freqs_x, freqs_y],
dim=-1).float().contiguous()[..., ::2]
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
cache = torch.view_as_complex(
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
return cache
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(
*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)

View File

@ -0,0 +1,670 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Optional, Union
import numpy as np
import torch
from transformers import PretrainedConfig
from .base import RotaryEmbedding
from .common import apply_rotary_emb_dispatch
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
base, is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@classmethod
def get_input_positions(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value."""
image_grid_thw = [] if image_grid_thw is None else image_grid_thw
video_grid_thw = [] if video_grid_thw is None else video_grid_thw
second_per_grid_ts = [] if second_per_grid_ts is None else \
second_per_grid_ts
llm_positions, mrope_position_delta = \
cls.get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
return llm_positions.tolist(), mrope_position_delta
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
elif hf_config.model_type in ["glm4v", "glm4v_moe"]:
return cls._glm4v_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _glm4v_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> list[list[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
]
@staticmethod
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
mrope_position_delta: int,
context_len: int, num_new_tokens: int):
values = np.arange(mrope_position_delta + context_len,
mrope_position_delta + context_len + num_new_tokens,
dtype=out.dtype)
out[:, out_offset:out_offset + num_new_tokens] = values
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float,
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from .base import RotaryEmbedding
class NTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with fixed and mixed NTK scaling.
https://kexue.fm/archives/9706 """
def __init__(self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
mixed_b: Optional[float] = None) -> None:
self.scaling_factor = scaling_factor
self.mixed_b = mixed_b
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
inv_freq = super()._compute_inv_freq(base)
if self.mixed_b is None:
inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
else:
a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
2)**self.mixed_b
lambda_1_m = (a * torch.arange(
1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
inv_freq = inv_freq / lambda_1_m
return inv_freq

View File

@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
import torch.nn as nn
from .common import rotate_neox
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
short_factor: list[float],
long_factor: list[float],
short_mscale: Optional[float] = None,
long_mscale: Optional[float] = None,
):
super().__init__()
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.rotary_dim = rotary_dim
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(
1 + math.log(scale) /
math.log(self.original_max_position_embeddings))
if short_mscale is None:
short_mscale = scaling_factor
if long_mscale is None:
long_mscale = scaling_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(dtype)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(dtype)
long_short_cache = torch.cat([short_cache, long_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
return inv_freq
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: list[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
k = self.original_max_position_embeddings
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = query_rot * cos + rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = key_rot * cos + rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2)

View File

@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .base import RotaryEmbedding
from .common import (yarn_find_correction_range, yarn_get_mscale,
yarn_linear_ramp_mask)
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache