Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 11:45:39 +08:00)
[feat] Support MRoPE + YaRN (#25384)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
parent 4741239db7
commit fc97733da8
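With this change, get_rope() builds a single MRotaryEmbedding that carries both the multimodal 3-D section split and the YaRN long-context parameters, instead of always falling back to the text-only YaRNScalingRotaryEmbedding. A minimal sketch of a rope_scaling dict that would take the new branch in the first hunk below; the key names and values here are illustrative assumptions, not copied from any particular model config:

rope_scaling = {
    "rope_type": "yarn",                        # YaRN long-context scaling
    "factor": 4.0,                              # becomes scaling_factor in get_rope()
    "original_max_position_embeddings": 32768,  # assumed key name
    "mrope_section": [16, 24, 24],              # t/h/w split; sums to rotary_dim // 2
    "mrope_interleaved": False,
}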
@@ -153,11 +153,23 @@ def get_rope(
             if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                      "beta_slow")
         }
-        rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
-                                                original_max_position,
-                                                base, is_neox_style,
-                                                scaling_factor, dtype,
-                                                **extra_kwargs)
+        if "mrope_section" in rope_scaling:
+            rotary_emb = MRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                original_max_position,
+                base,
+                is_neox_style,
+                dtype,
+                mrope_section=rope_scaling["mrope_section"],
+                mrope_interleaved=rope_scaling.get("mrope_interleaved",
+                                                   False),
+                scaling_factor=scaling_factor,
+                **extra_kwargs)
+        else:
+            rotary_emb = YaRNScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, dtype, **extra_kwargs)
     elif scaling_type == "deepseek_yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling[
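In effect, the branch above constructs the multimodal rotary embedding directly with the YaRN keyword arguments added in the hunks below. A self-contained sketch of such a call; the module path and positional parameter names are inferred from the relative imports and the base RotaryEmbedding signature, and every value is illustrative:

import torch

from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

# Illustrative values only; in practice they come from the model config via get_rope().
rotary_emb = MRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=32768,   # the *original* max position under YaRN
    base=1_000_000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    mrope_section=[16, 24, 24],
    mrope_interleaved=False,
    scaling_factor=4.0,              # not None, so the YaRN cache path is taken
)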
@@ -12,6 +12,7 @@ from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
 
 @triton.jit
@@ -213,7 +214,27 @@ class MRotaryEmbedding(RotaryEmbedding):
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
         mrope_interleaved: bool = False,
+        # YaRN parameters.
+        *,
+        scaling_factor: Optional[float] = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
+
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        if self.scaling_factor is not None:
+            # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor) * attn_factor)
+        else:
+            self.mscale = 1.0
+
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
         # a larger the cos and sin cache.
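The mscale set here uses the imported yarn_get_mscale helper. It appears to be the standard YaRN attention-magnitude correction; a self-contained sketch of that formula for reference, not a verbatim copy of the library source:

import math

def yarn_get_mscale(scale: float = 1.0) -> float:
    # YaRN magnitude correction (Peng et al., 2023): identity at or below 1x,
    # otherwise grows slowly as 0.1 * ln(scale) + 1.0.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

With the default attn_factor of 1 and a scaling factor of 4, this gives an mscale of roughly 1.14.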
@@ -226,6 +247,16 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_inv_freq(base)
+        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_cos_sin_cache()
+        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
+
     def forward_native(
         self,
         positions: torch.Tensor,
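These overrides keep MRotaryEmbedding a direct subclass of RotaryEmbedding and borrow YaRNScalingRotaryEmbedding's cache math by calling its methods unbound with self, instead of introducing a second base class. A stripped-down illustration of that dispatch pattern, with made-up class names:

class Base:
    def _compute(self) -> str:
        return "plain rope cache"


class Yarn(Base):
    def _compute(self) -> str:
        return f"yarn cache, factor={self.scaling_factor}"


class MRope(Base):
    def __init__(self, scaling_factor=None):
        self.scaling_factor = scaling_factor

    def _compute(self) -> str:
        if self.scaling_factor is None:
            return super()._compute()   # plain RoPE path
        return Yarn._compute(self)      # reuse YaRN's math without multiple inheritance


print(MRope()._compute())       # plain rope cache
print(MRope(4.0)._compute())    # yarn cache, factor=4.0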