mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[XPU][bugfix] fix rope for llama4 and deepseek (#25145)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
parent
48eb8eba58
commit
b798e39f93
@ -14,7 +14,7 @@ from .rocm_aiter_rope_ops import (
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
class RotaryEmbedding(CustomOp):
|
||||
class RotaryEmbeddingBase(CustomOp):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
@ -86,6 +86,21 @@ class RotaryEmbedding(CustomOp):
|
||||
):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .base import RotaryEmbeddingBase
|
||||
from .common import (
|
||||
rotate_gptj,
|
||||
rotate_neox,
|
||||
@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
"""RotaryEmbedding extended with YaRN method.
|
||||
|
||||
Credits to Peng et al. github.com/jquesnelle/yarn
|
||||
|
||||
@ -5,10 +5,10 @@ import math
|
||||
|
||||
import torch
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .base import RotaryEmbeddingBase
|
||||
|
||||
|
||||
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
@ -78,10 +78,3 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(query, key)
|
||||
|
||||
def forward_hip( # type: ignore[override]
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(query, key)
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .base import RotaryEmbeddingBase
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
||||
|
||||
@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
|
||||
return x_t
|
||||
|
||||
|
||||
class MRotaryEmbedding(RotaryEmbedding):
|
||||
class MRotaryEmbedding(RotaryEmbeddingBase):
|
||||
"""Rotary Embedding with Multimodal Sections."""
|
||||
|
||||
def __init__(
|
||||
@ -357,24 +357,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
@staticmethod
|
||||
def get_next_input_positions(
|
||||
mrope_position_delta: int,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user