[XPU][bugfix] fix rope for llama4 and deepseek (#25145)
Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 48eb8eba58
commit b798e39f93
@@ -14,7 +14,7 @@ from .rocm_aiter_rope_ops import (
 
 
 @CustomOp.register("rotary_embedding")
-class RotaryEmbedding(CustomOp):
+class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(
@@ -86,6 +86,21 @@ class RotaryEmbedding(CustomOp):
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
+
+class RotaryEmbedding(RotaryEmbeddingBase):
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,
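The hunks above split the original `RotaryEmbedding(CustomOp)` into a `RotaryEmbeddingBase` that carries the shared setup (notably the `cos_sin_cache`) and a thin `RotaryEmbedding` subclass that retains the existing forward paths. As a rough mental model, here is a minimal sketch of that split (hypothetical names, plain `torch.nn.Module` instead of vLLM's `CustomOp` dispatch machinery):

```python
import torch


class RotaryEmbeddingBaseSketch(torch.nn.Module):
    """Hypothetical stand-in for RotaryEmbeddingBase: shared cache + native math."""

    def __init__(self, rotary_dim: int, max_positions: int, base: float) -> None:
        super().__init__()
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
        )
        freqs = torch.outer(torch.arange(max_positions, dtype=torch.float32), inv_freq)
        # cos/sin cache reused by every rope variant
        self.register_buffer(
            "cos_sin_cache", torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        )

    def forward_native(self, positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # portable NeoX-style rotation over the rotary channels
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


class RotaryEmbeddingSketch(RotaryEmbeddingBaseSketch):
    # Only this leaf is meant to carry device-specific overrides
    # (forward_cuda / forward_xpu / ...); scaling and multimodal variants
    # subclass the base and keep the native path.
    pass
```

As suggested by the commit title, the point of the split is that scaling and multimodal variants can now inherit from the base class and keep their native math instead of picking up platform-specific kernels that do not implement their behavior.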
@@ -7,7 +7,7 @@ import torch
 
 from vllm.platforms import current_platform
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
     """RotaryEmbedding extended with YaRN method.
 
     Credits to Peng et al. github.com/jquesnelle/yarn
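The hunks above are from the DeepSeek scaling rope module: the import and the base class switch from `RotaryEmbedding` to `RotaryEmbeddingBase`. For reference, the `yarn_get_mscale` helper whose body shows in the second hunk computes the YaRN attention-magnitude correction; a small worked example follows (the `scale <= 1` early return is the usual YaRN convention and is assumed here, since the hunk only shows the final line):

```python
import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # no attention scaling at or below the native context length (assumed branch)
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# e.g. a 40x context extension with mscale=1 scales attention by ~1.37
print(round(yarn_get_mscale(40.0), 4))  # 1.3689
```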
@@ -5,10 +5,10 @@ import math
 
 import torch
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 
 
-class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
         self,
         head_size: int,
@@ -78,10 +78,3 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         return self.forward_native(query, key)
-
-    def forward_hip(  # type: ignore[override]
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(query, key)
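The hunks above are from the Llama4 vision rope module. Besides rebasing the class onto `RotaryEmbeddingBase`, the second hunk removes a `forward_hip` override that only delegated to `forward_native`. Assuming the op dispatcher already falls back to the native implementation when no platform-specific method is defined, such pass-through overrides add nothing; the sketch below illustrates that fallback idea with a hypothetical resolver, not vLLM's actual `CustomOp` resolution:

```python
from typing import Callable


def resolve_forward(op: object, platform: str) -> Callable:
    # prefer a platform-specific method if the op defines one, else fall back
    return getattr(op, f"forward_{platform}", None) or op.forward_native


class VisionRopeSketch:
    def forward_native(self, query, key=None):
        return query, key  # stand-in for the real rotation


op = VisionRopeSketch()
# without a forward_hip / forward_xpu override, both resolve to the native path
assert resolve_forward(op, "hip") == resolve_forward(op, "xpu") == op.forward_native
```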
@@ -7,7 +7,7 @@ import torch
 
 from vllm.triton_utils import tl, triton
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import apply_rotary_emb_dispatch
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
     return x_t
 
 
-class MRotaryEmbedding(RotaryEmbedding):
+class MRotaryEmbedding(RotaryEmbeddingBase):
     """Rotary Embedding with Multimodal Sections."""
 
     def __init__(
@@ -357,24 +357,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
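The hunks above are from the multimodal (M-RoPE) module: the class is rebased onto `RotaryEmbeddingBase` and the pass-through `forward_xpu` / `forward_cpu` overrides are dropped, matching the Llama4 change. For context on the `mrope_section` argument visible in the hunk header, M-RoPE splits the rotary frequency bands into temporal/height/width sections, each reading positions from one row of a `(3, num_tokens)` position tensor; a toy illustration of that sectioning (not vLLM's `apply_interleaved_rope`):

```python
import torch


def merge_mrope_sections(cos_3d: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Toy M-RoPE sectioning: cos_3d is (3, num_tokens, num_bands), one row per
    temporal/height/width axis; each frequency section takes its values from the
    corresponding axis."""
    chunks = torch.split(cos_3d, mrope_section, dim=-1)
    return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)


cos_3d = torch.randn(3, 5, 32)  # toy: 5 tokens, 32 frequency bands
print(merge_mrope_sections(cos_3d, [16, 8, 8]).shape)  # torch.Size([5, 32])
```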