[feat] Support MRoPE + YaRN (#25384)

Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Authored by JJJYmmm on 2025-09-23 11:04:47 +08:00, committed by GitHub
parent 4741239db7
commit fc97733da8
2 changed files with 48 additions and 5 deletions


@@ -153,11 +153,23 @@ def get_rope(
                 if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                          "beta_slow")
             }
-            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
-                                                    original_max_position,
-                                                    base, is_neox_style,
-                                                    scaling_factor, dtype,
-                                                    **extra_kwargs)
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    original_max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
+                                                       False),
+                    scaling_factor=scaling_factor,
+                    **extra_kwargs)
+            else:
+                rotary_emb = YaRNScalingRotaryEmbedding(
+                    head_size, rotary_dim, original_max_position, base,
+                    is_neox_style, scaling_factor, dtype, **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
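
With this hunk, a YaRN-scaled model whose rope_scaling also carries an mrope_section now gets an MRotaryEmbedding configured with the YaRN parameters, instead of the plain YaRNScalingRotaryEmbedding. A minimal sketch of a rope_scaling dict that would take the new branch; the key names follow the hunk above, but the concrete values are illustrative assumptions, not taken from this commit:

# Illustrative only: values are made up, not from this PR.
rope_scaling = {
    "rope_type": "yarn",                        # selects the YaRN branch
    "factor": 4.0,                              # becomes scaling_factor
    "original_max_position_embeddings": 32768,  # becomes original_max_position
    "mrope_section": [16, 24, 24],              # must sum to rotary_dim // 2
    "mrope_interleaved": False,                 # optional, defaults to False
    "beta_fast": 32,                            # forwarded via extra_kwargs
    "beta_slow": 1,
}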


@@ -12,6 +12,7 @@ from vllm.triton_utils import tl, triton

 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale

 @triton.jit
@@ -213,7 +214,27 @@ class MRotaryEmbedding(RotaryEmbedding):
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
         mrope_interleaved: bool = False,
+        # YaRN parameters.
+        *,
+        scaling_factor: Optional[float] = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        if self.scaling_factor is not None:
+            # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor) * attn_factor)
+        else:
+            self.mscale = 1.0
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
         # a larger cos and sin cache.
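
The new mscale attribute only departs from 1.0 when a YaRN scaling_factor is given; it is the usual YaRN magnitude correction applied to the cos/sin cache. A small sanity-check sketch, assuming yarn_get_mscale uses the standard 0.1 * ln(s) + 1.0 form (the helper imported above may differ in detail):

import math

def yarn_get_mscale(scale: float = 1.0) -> float:
    # Assumed standard YaRN magnitude correction (see yarn_scaling_rope).
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

scaling_factor, attn_factor = 4.0, 1.0
mscale = float(yarn_get_mscale(scaling_factor) * attn_factor)
print(mscale)  # ~1.1386: the cos/sin cache is scaled up slightly under YaRN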
@@ -226,6 +247,16 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2

+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_inv_freq(base)
+        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_cos_sin_cache()
+        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
+
     def forward_native(
         self,
         positions: torch.Tensor,
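
The two overrides above reuse YaRN's frequency and cache computation without multiple inheritance: when scaling_factor is set, the unbound YaRNScalingRotaryEmbedding methods are called directly with self. A minimal, runnable sketch of that pattern; the class names below are illustrative stand-ins, not the vLLM classes:

class Base:
    def inv_freq(self) -> str:
        return "base"

class Yarn(Base):
    def inv_freq(self) -> str:
        return "yarn"

class MRope(Base):
    def __init__(self, scaling_factor=None) -> None:
        self.scaling_factor = scaling_factor

    def inv_freq(self) -> str:
        if self.scaling_factor is None:
            return super().inv_freq()   # plain RoPE path
        return Yarn.inv_freq(self)      # borrow the YaRN implementation

assert MRope().inv_freq() == "base"
assert MRope(4.0).inv_freq() == "yarn"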