diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index c1259a1b11ea5..b94c82e132583 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)
 
     def forward(
         self,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 2cbd3b584c06e..1b88e9a3b8057 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb