mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:05:01 +08:00
[feat] Support MRoPE + YaRN (#25384)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
This commit is contained in:
parent
4741239db7
commit
fc97733da8
@ -153,11 +153,23 @@ def get_rope(
|
||||
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
"beta_slow")
|
||||
}
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
original_max_position,
|
||||
base, is_neox_style,
|
||||
scaling_factor, dtype,
|
||||
**extra_kwargs)
|
||||
if "mrope_section" in rope_scaling:
|
||||
rotary_emb = MRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
mrope_interleaved=rope_scaling.get("mrope_interleaved",
|
||||
False),
|
||||
scaling_factor=scaling_factor,
|
||||
**extra_kwargs)
|
||||
else:
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, original_max_position, base,
|
||||
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling[
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.triton_utils import tl, triton
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -213,7 +214,27 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
dtype: torch.dtype,
|
||||
mrope_section: Optional[list[int]] = None,
|
||||
mrope_interleaved: bool = False,
|
||||
# YaRN parameters.
|
||||
*,
|
||||
scaling_factor: Optional[float] = None,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
) -> None:
|
||||
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
if self.scaling_factor is not None:
|
||||
# Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
else:
|
||||
self.mscale = 1.0
|
||||
|
||||
# In Qwen2.5-VL, the maximum index value is related to the duration of
|
||||
# the input video. We enlarge max_position_embeddings to 4 times to get
|
||||
# a larger the cos and sin cache.
|
||||
@ -226,6 +247,16 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
if self.mrope_section:
|
||||
assert sum(self.mrope_section) == rotary_dim // 2
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
if self.scaling_factor is None:
|
||||
return super()._compute_inv_freq(base)
|
||||
return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
if self.scaling_factor is None:
|
||||
return super()._compute_cos_sin_cache()
|
||||
return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user