mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
Revert "[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module." (#22956)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
3d232dbd19
commit
b2f6c247a9
@ -8,7 +8,6 @@ import torch
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
|
||||
from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
@ -36,7 +35,6 @@ class RotaryEmbedding(CustomOp):
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
@ -121,75 +119,6 @@ class RotaryEmbedding(CustomOp):
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_nope_first=False,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# currently only rotary embedding ops from AITER package are
|
||||
# supported for HiP forward.
|
||||
if self.is_rocm_aiter_enabled:
|
||||
return self.forward_hip_rocm_aiter(positions, query, key, offsets,
|
||||
is_nope_first)
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_hip_rocm_aiter(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
# if is_nope_first
|
||||
# [[batch_size, seq_len, num_heads, nope_size+rope_size]
|
||||
# if NOT is_nope_first
|
||||
# [[batch_size, seq_len, num_heads, rope_size+nope_size],
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_nope_first: bool = False,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
|
||||
|
||||
cos = cos.unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
rotate_style = 0 if self.is_neox_style else 1
|
||||
|
||||
num_tokens = positions.numel()
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(1, num_tokens, -1, self.head_size)
|
||||
if key is not None:
|
||||
key_shape = key.shape
|
||||
key = key.view(1, num_tokens, -1, self.head_size)
|
||||
|
||||
positions = positions.view(*query.shape[:2])
|
||||
if offsets is not None:
|
||||
offsets = offsets.view(*query.shape[:2])
|
||||
|
||||
if not is_nope_first:
|
||||
query_ = query[..., :self.rotary_dim]
|
||||
key_ = key[..., :self.rotary_dim] if key is not None else None
|
||||
else:
|
||||
query_ = query[..., -self.rotary_dim:]
|
||||
key_ = key[..., -self.rotary_dim:] if key is not None else None
|
||||
|
||||
if key_ is None:
|
||||
torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
|
||||
positions, sin, cos, query_, offsets, rotate_style,
|
||||
is_nope_first)
|
||||
return query.view(query_shape), None
|
||||
|
||||
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
|
||||
positions, sin, cos, query_, key_, offsets, rotate_style,
|
||||
is_nope_first)
|
||||
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
||||
return ramp_func
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
def yarn_get_mscale(scale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -9,7 +10,13 @@ from vllm.platforms import current_platform
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
|
||||
yarn_get_mscale, yarn_linear_ramp_mask)
|
||||
yarn_linear_ramp_mask)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
@ -89,9 +96,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
if self.is_rocm_aiter_enabled:
|
||||
return self.forward_hip_rocm_aiter(positions, query, key, offsets)
|
||||
|
||||
assert key is not None
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
|
||||
@ -1,127 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_rotary_embedding_enabled() -> bool:
|
||||
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
import aiter as ops
|
||||
if offsets is None:
|
||||
ops.rope_cached_positions_fwd_inplace(
|
||||
query,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=is_nope_first,
|
||||
)
|
||||
else:
|
||||
ops.rope_cached_positions_offsets_fwd_inplace(
|
||||
query,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
offsets,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=is_nope_first,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
import aiter as ops
|
||||
if offsets is None:
|
||||
ops.rope_cached_positions_2c_fwd_inplace(
|
||||
query,
|
||||
key,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=is_nope_first,
|
||||
)
|
||||
else:
|
||||
ops.rope_cached_positions_offsets_2c_fwd_inplace(
|
||||
query,
|
||||
key,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
offsets,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=is_nope_first,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if is_rocm_rotary_embedding_enabled():
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
|
||||
op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
|
||||
mutates_args=["key", "query"],
|
||||
fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
|
||||
op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
|
||||
mutates_args=["query"],
|
||||
fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user