Revert "[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module." (#22956)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: TJian · 2025-08-14 23:39:19 -07:00 · committed by GitHub
parent 3d232dbd19
commit b2f6c247a9
4 changed files with 10 additions and 204 deletions

View File

@@ -8,7 +8,6 @@ import torch
 from vllm.model_executor.custom_op import CustomOp
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled


 @CustomOp.register("rotary_embedding")
@@ -36,7 +35,6 @@ class RotaryEmbedding(CustomOp):
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
-        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()

     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -121,75 +119,6 @@ class RotaryEmbedding(CustomOp):
                              self.cos_sin_cache, self.is_neox_style)
         return query, key

-    def forward_hip(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first=False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # currently only rotary embedding ops from AITER package are
-        # supported for HiP forward.
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
-                                               is_nope_first)
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_hip_rocm_aiter(
-        self,
-        positions: torch.Tensor,
-        # if is_nope_first
-        # [[batch_size, seq_len, num_heads, nope_size+rope_size]
-        # if NOT is_nope_first
-        # [[batch_size, seq_len, num_heads, rope_size+nope_size],
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first: bool = False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
-        cos = cos.unsqueeze(-2).unsqueeze(-2)
-        sin = sin.unsqueeze(-2).unsqueeze(-2)
-        rotate_style = 0 if self.is_neox_style else 1
-
-        num_tokens = positions.numel()
-        query_shape = query.shape
-        query = query.view(1, num_tokens, -1, self.head_size)
-        if key is not None:
-            key_shape = key.shape
-            key = key.view(1, num_tokens, -1, self.head_size)
-        positions = positions.view(*query.shape[:2])
-        if offsets is not None:
-            offsets = offsets.view(*query.shape[:2])
-
-        if not is_nope_first:
-            query_ = query[..., :self.rotary_dim]
-            key_ = key[..., :self.rotary_dim] if key is not None else None
-        else:
-            query_ = query[..., -self.rotary_dim:]
-            key_ = key[..., -self.rotary_dim:] if key is not None else None
-
-        if key_ is None:
-            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
-                positions, sin, cos, query_, offsets, rotate_style,
-                is_nope_first)
-            return query.view(query_shape), None
-
-        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
-            positions, sin, cos, query_, key_, offsets, rotate_style,
-            is_nope_first)
-        return query.view(query_shape), key.view(key_shape)
-
     def forward_xpu(
         self,
         positions: torch.Tensor,

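With the AITER HIP path removed from RotaryEmbedding above, ROCm builds fall back to the default CustomOp dispatch (forward_native / forward_cuda). For reference, here is a minimal, self-contained sketch of applying a cached cos/sin table to the first rotary_dim channels of q/k, the same reshaping idea the deleted forward_hip_rocm_aiter handed off to AITER. The function name, argument layout, and helper below are illustrative only, not vLLM's actual forward_native.

```python
# Illustrative sketch only (assumed layout: cos_sin_cache[pos] = [cos | sin]).
import torch


def apply_rope_cached(positions: torch.Tensor,      # [num_tokens], int64
                      query: torch.Tensor,          # [num_tokens, num_heads, head_size]
                      key: torch.Tensor,            # [num_tokens, num_kv_heads, head_size]
                      cos_sin_cache: torch.Tensor,  # [max_pos, rotary_dim]
                      rotary_dim: int,
                      is_neox_style: bool = True):
    # Gather per-token cos/sin and broadcast over the head dimension.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)

    def rotate(x: torch.Tensor) -> torch.Tensor:
        rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
        if is_neox_style:            # rotate halves: [x1 | x2]
            x1, x2 = rot.chunk(2, dim=-1)
        else:                        # GPT-J style: rotate interleaved pairs
            x1, x2 = rot[..., ::2], rot[..., 1::2]
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if is_neox_style:
            out = torch.cat((o1, o2), dim=-1)
        else:
            out = torch.stack((o1, o2), dim=-1).flatten(-2)
        return torch.cat((out, rest), dim=-1)

    return rotate(query), rotate(key)
```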
View File

@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func


-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+    return 0.1 * math.log(scale) + 1.0

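The hunk above restores the single-argument yarn_get_mscale in common.py, while the next file re-adds a two-argument variant locally for DeepSeek. Below is a small sketch of the two forms side by side with a worked value; the _common and _deepseek suffixes are labels invented for this example only.

```python
import math


def yarn_get_mscale_common(scale: float = 1) -> float:
    # Form kept in common.py after the revert.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


def yarn_get_mscale_deepseek(scale: float = 1, mscale: float = 1) -> float:
    # Form re-added locally in deepseek_scaling_rope.py (next file).
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# For a 16x context extension: 0.1 * ln(16) + 1 ≈ 1.277; the mscale argument
# simply rescales the log term (e.g. mscale = 0.707 gives ≈ 1.196).
print(yarn_get_mscale_common(16.0))           # ~1.2773
print(yarn_get_mscale_deepseek(16.0, 0.707))  # ~1.1960
```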
View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
 from typing import Optional

 import torch
@@ -9,7 +10,13 @@ from vllm.platforms import current_platform
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_get_mscale, yarn_linear_ramp_mask)
+                     yarn_linear_ramp_mask)
+
+
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0


 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -89,9 +96,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]

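The local yarn_get_mscale re-added above is what DeepSeek-style YaRN uses to scale its cached cos/sin values. The sketch below shows that idea under simplifying assumptions (plain positional interpolation instead of vLLM's full YaRN frequency correction); build_scaled_cos_sin_cache and its parameters are names invented for this sketch, not vLLM's API.

```python
import math

import torch


def build_scaled_cos_sin_cache(rotary_dim: int,
                               base: float,
                               max_positions: int,
                               scaling_factor: float,
                               mscale: float = 1.0,
                               mscale_all_dim: float = 0.0) -> torch.Tensor:
    # Simplified: interpolate positions instead of YaRN's per-band correction.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(int(max_positions * scaling_factor), dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq / scaling_factor)

    def get_mscale(scale: float, m: float) -> float:
        # Same formula as the two-argument yarn_get_mscale above.
        return 1.0 if scale <= 1 else 0.1 * m * math.log(scale) + 1.0

    # DeepSeek-style scaling folds the attention rescaling into the cache.
    attn_scale = get_mscale(scaling_factor, mscale) / get_mscale(
        scaling_factor, mscale_all_dim)
    cos, sin = freqs.cos() * attn_scale, freqs.sin() * attn_scale
    return torch.cat((cos, sin), dim=-1)  # [num_positions, rotary_dim]


cache = build_scaled_cos_sin_cache(64, 10000.0, 4096, 4.0, mscale=0.707)
print(cache.shape)  # torch.Size([16384, 64])
```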
View File

@@ -1,127 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
-
-import torch
-
-import vllm.envs as envs
-from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
-
-
-def is_rocm_rotary_embedding_enabled() -> bool:
-    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
-        positions: torch.Tensor,
-        sin: torch.Tensor,
-        cos: torch.Tensor,
-        query: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-        rotate_style: int = 0,
-        is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
-        positions: torch.Tensor,
-        sin: torch.Tensor,
-        cos: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-        rotate_style: int = 0,
-        is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
-        positions: torch.Tensor,
-        sin: torch.Tensor,
-        cos: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-        rotate_style: int = 0,
-        is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
-        positions: torch.Tensor,
-        sin: torch.Tensor,
-        cos: torch.Tensor,
-        query: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-        rotate_style: int = 0,
-        is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-if is_rocm_rotary_embedding_enabled():
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
-        mutates_args=["key", "query"],
-        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
-        mutates_args=["query"],
-        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
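The deleted module's purpose was to expose AITER's in-place rope kernels as PyTorch custom ops with no-op fake implementations, so torch.compile can trace them without running the kernel. Below is a minimal, self-contained sketch of that pattern using the public torch.library API (torch >= 2.4) rather than vLLM's direct_register_custom_op helper; the rope_sketch namespace and the trivial in-place body are placeholders, not AITER calls.

```python
import torch


@torch.library.custom_op("rope_sketch::rotary_emb_inplace",
                         mutates_args=("query",))
def rotary_emb_inplace(positions: torch.Tensor, sin: torch.Tensor,
                       cos: torch.Tensor, query: torch.Tensor) -> None:
    # Placeholder for a real in-place kernel such as AITER's
    # rope_cached_positions_fwd_inplace; here we only touch `query`.
    query.mul_(1.0)


@rotary_emb_inplace.register_fake
def _(positions: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor,
      query: torch.Tensor) -> None:
    # Fake (meta) impl: no computation, mirroring the *_fake functions above.
    pass


# Callers then dispatch through torch.ops, as the removed
# forward_hip_rocm_aiter did with torch.ops.vllm.rocm_aiter_*:
query = torch.randn(1, 16, 4, 64)
positions = torch.arange(16).view(1, 16)
cos = sin = torch.randn(1, 16, 1, 1, 32)
torch.ops.rope_sketch.rotary_emb_inplace(positions, sin, cos, query)
```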