diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 711902f0cc67..91276320df4d 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -14,7 +14,7 @@ from .rocm_aiter_rope_ops import (
 
 
 @CustomOp.register("rotary_embedding")
-class RotaryEmbedding(CustomOp):
+class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(
@@ -86,6 +86,21 @@ class RotaryEmbedding(CustomOp):
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
+
+class RotaryEmbedding(RotaryEmbeddingBase):
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index 2e5efec06663..d9134f05fddf 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -7,7 +7,7 @@ import torch
 
 from vllm.platforms import current_platform
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
     """RotaryEmbedding extended with YaRN method.
 
     Credits to Peng et al. github.com/jquesnelle/yarn
diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
index 6241cb5abbc8..9fdac309df7e 100644
--- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -5,10 +5,10 @@ import math
 
 import torch
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 
 
-class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
         self,
         head_size: int,
@@ -78,10 +78,3 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         return self.forward_native(query, key)
-
-    def forward_hip(  # type: ignore[override]
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(query, key)
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index d269733083d8..3c184ce9d631 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -7,7 +7,7 @@ import torch
 
 from vllm.triton_utils import tl, triton
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import apply_rotary_emb_dispatch
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
     return x_t
 
 
-class MRotaryEmbedding(RotaryEmbedding):
+class MRotaryEmbedding(RotaryEmbeddingBase):
    """Rotary Embedding with Multimodal Sections."""
 
     def __init__(
@@ -357,24 +357,6 @@ class MRotaryEmbedding(RotaryEmbedding):
             key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
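Note for reviewers: the split above moves the shared cos/sin cache machinery into `RotaryEmbeddingBase` while the new `RotaryEmbedding` subclass keeps the default forward path, so variants such as `DeepseekScalingRotaryEmbedding`, `Llama4VisionRotaryEmbedding`, and `MRotaryEmbedding` now inherit only the former. The deleted `forward_hip`, `forward_xpu`, and `forward_cpu` overrides merely delegated to `forward_native`, so removing them presumably leaves that behavior to `CustomOp`'s default per-backend dispatch. The sketch below illustrates the inheritance pattern with simplified, hypothetical classes (`RotaryBase`, `DefaultRotary`, `ScalingRotary`); it is a minimal standalone analogue, not the vLLM code itself.

```python
import torch
from torch import nn


class RotaryBase(nn.Module):
    """Shared machinery only (analogue of RotaryEmbeddingBase): builds and
    caches cos/sin tables, but defines no forward path of its own."""

    def __init__(self, rotary_dim: int, max_positions: int, base: float) -> None:
        super().__init__()
        self.rotary_dim = rotary_dim
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
        )
        t = torch.arange(max_positions, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # [max_positions, rotary_dim // 2]
        self.register_buffer("cos_cache", freqs.cos(), persistent=False)
        self.register_buffer("sin_cache", freqs.sin(), persistent=False)


class DefaultRotary(RotaryBase):
    """Analogue of the new RotaryEmbedding subclass: inherits the cached
    tables and adds a default (NeoX-style, half-split) rotation."""

    def forward(self, positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # positions: [tokens] (long), x: [tokens, heads, rotary_dim]
        cos = self.cos_cache[positions].unsqueeze(-2)  # [tokens, 1, rotary_dim // 2]
        sin = self.sin_cache[positions].unsqueeze(-2)
        x1, x2 = x[..., : self.rotary_dim // 2], x[..., self.rotary_dim // 2 :]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


class ScalingRotary(RotaryBase):
    """Analogue of a scaling variant (e.g. DeepseekScalingRotaryEmbedding):
    reuses the base's cache plumbing but rescales the tables, without
    accidentally inheriting DefaultRotary's forward."""

    def __init__(
        self, rotary_dim: int, max_positions: int, base: float, mscale: float
    ) -> None:
        super().__init__(rotary_dim, max_positions, base)
        # Hypothetical scaling: stretch the cached tables by a fixed factor.
        self.cos_cache *= mscale
        self.sin_cache *= mscale

    def forward(self, positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Variant-specific forward; identical shape handling here, but free
        # to diverge from DefaultRotary.
        cos = self.cos_cache[positions].unsqueeze(-2)
        sin = self.sin_cache[positions].unsqueeze(-2)
        x1, x2 = x[..., : self.rotary_dim // 2], x[..., self.rotary_dim // 2 :]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


# Usage sketch:
# rope = DefaultRotary(rotary_dim=64, max_positions=4096, base=10000.0)
# q = torch.randn(8, 4, 64)   # [tokens, heads, rotary_dim]
# pos = torch.arange(8)
# out = rope(pos, q)
```

With this layout, adding a new RoPE variant means subclassing `RotaryBase` and supplying its own forward, rather than inheriting `DefaultRotary`'s implementation, which mirrors why the diff rebases the existing variants onto `RotaryEmbeddingBase`.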