[XPU][bugfix] fix rope for llama4 and deepseek (#25145)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Yan Ma 2025-10-30 09:43:13 +08:00 committed by GitHub
parent 48eb8eba58
commit b798e39f93
4 changed files with 22 additions and 32 deletions


@@ -14,7 +14,7 @@ from .rocm_aiter_rope_ops import (
 @CustomOp.register("rotary_embedding")
-class RotaryEmbedding(CustomOp):
+class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""

     def __init__(
@@ -86,6 +86,21 @@ class RotaryEmbedding(CustomOp):
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

+
+class RotaryEmbedding(RotaryEmbeddingBase):
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,
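The hunk above splits the old RotaryEmbedding into RotaryEmbeddingBase (shared setup, cache handling, reference math) and a new RotaryEmbedding subclass that carries the platform-specific fast paths. The sketch below is a minimal, self-contained illustration of why that split matters for the XPU fix; the class names and the getattr-based dispatch are stand-ins, not vLLM's actual CustomOp machinery.

# Illustrative sketch only: stand-in classes, not vLLM's CustomOp dispatch.
class RotaryEmbeddingBaseSketch:
    """Shared setup plus the reference (pure PyTorch) RoPE math."""

    def forward_native(self, positions, query, key=None):
        # the reference rotary math would live here
        return query, key

    def forward(self, positions, query, key=None):
        # hypothetical dispatch: prefer a platform override if the subclass
        # defines one, otherwise fall back to forward_native
        impl = getattr(self, "forward_xpu", self.forward_native)
        return impl(positions, query, key)


class RotaryEmbeddingSketch(RotaryEmbeddingBaseSketch):
    """Plain RoPE: keeps the fused platform kernel."""

    def forward_xpu(self, positions, query, key=None):
        # stand-in for an XPU kernel that only handles plain neox/gptj RoPE
        return query, key


class DeepseekScalingSketch(RotaryEmbeddingBaseSketch):
    """Scaled/vision variants now inherit only the base, so their own
    forward_native runs on every platform instead of the plain-RoPE kernel."""

Before this change, DeepseekScalingRotaryEmbedding and Llama4VisionRotaryEmbedding derived from RotaryEmbedding and so picked up its platform fast paths, which appears to be the XPU bug the commit title refers to; deriving them from the base class keeps them on forward_native.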


@@ -7,7 +7,7 @@ import torch

 from vllm.platforms import current_platform

-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


-class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
     """RotaryEmbedding extended with YaRN method.

     Credits to Peng et al. github.com/jquesnelle/yarn
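As a quick sanity check on the yarn_get_mscale context shown above (assuming the usual guard that returns 1.0 for scale <= 1, which is outside this hunk):

import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:  # assumed guard; only the final return appears in the hunk
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


print(yarn_get_mscale(4.0))   # 0.1 * ln(4) + 1  ~= 1.139
print(yarn_get_mscale(40.0))  # 0.1 * ln(40) + 1 ~= 1.369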


@@ -5,10 +5,10 @@ import math

 import torch

-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase


-class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
         self,
         head_size: int,
@@ -78,10 +78,3 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         return self.forward_native(query, key)
-
-    def forward_hip(  # type: ignore[override]
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(query, key)
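The deleted forward_hip was a pure pass-through to forward_native, identical to the method whose tail forms the three kept context lines above it. With Llama4VisionRotaryEmbedding now deriving from RotaryEmbeddingBase there is no inherited HIP fast path left to shadow, so the override is redundant. A stand-in sketch (the kept method's name is an assumption, not confirmed by this hunk):

# Stand-in sketch: a forward_hip that only forwards to forward_native adds
# nothing once the parent class has no HIP-specific path of its own.
class VisionRopeSketch:
    def forward_native(self, query, key=None):
        return query, key

    def forward_cuda(self, query, key=None):  # assumed name for the kept method
        return self.forward_native(query, key)

    # forward_hip deliberately absent: a backend lookup that falls back to the
    # CUDA or native path would land in the same place anyway.


rope = VisionRopeSketch()
assert rope.forward_cuda("q", "k") == ("q", "k")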


@@ -7,7 +7,7 @@ import torch

 from vllm.triton_utils import tl, triton

-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import apply_rotary_emb_dispatch
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale

@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
     return x_t


-class MRotaryEmbedding(RotaryEmbedding):
+class MRotaryEmbedding(RotaryEmbeddingBase):
     """Rotary Embedding with Multimodal Sections."""

     def __init__(
@@ -357,24 +357,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
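Same cleanup for MRotaryEmbedding: the removed forward_xpu and forward_cpu only forwarded to forward_native, and with RotaryEmbeddingBase as the parent there is no platform-specific implementation left for them to override. Assuming backends without an override fall back to forward_native, as in the sketch below (the per-backend lookup helper is hypothetical, not vLLM's API), deleting them changes nothing:

# Hypothetical per-backend lookup, only to illustrate the fallback.
class MRopeBaseSketch:
    def forward_native(self, positions, query, key=None, offsets=None):
        return "native", query, key

    def run(self, backend, *args):
        impl = getattr(self, f"forward_{backend}", self.forward_native)
        return impl(*args)


class MRopeSketch(MRopeBaseSketch):
    """No forward_xpu/forward_cpu needed: both resolve to forward_native."""


m = MRopeSketch()
assert m.run("xpu", None, "q")[0] == "native"
assert m.run("cpu", None, "q")[0] == "native"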