From b2f6c247a9b84556a8ea0e75bb4a2db765ff3315 Mon Sep 17 00:00:00 2001
From: TJian
Date: Thu, 14 Aug 2025 23:39:19 -0700
Subject: [PATCH] Revert "[ROCm][AITER] Support AITER Rope ops in
 RotaryEmbedding Module." (#22956)

Signed-off-by: vllmellm
Co-authored-by: vllmellm
---
 .../layers/rotary_embedding/base.py           |  71 ----------
 .../layers/rotary_embedding/common.py         |   4 +-
 .../rotary_embedding/deepseek_scaling_rope.py |  12 +-
 .../rotary_embedding/rocm_aiter_rope_ops.py   | 127 ------------------
 4 files changed, 10 insertions(+), 204 deletions(-)
 delete mode 100644 vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py

diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 6dfc28be7da1..10fce857a8ae 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -8,7 +8,6 @@ import torch
 
 from vllm.model_executor.custom_op import CustomOp
 
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
 
 
 @CustomOp.register("rotary_embedding")
@@ -36,7 +35,6 @@ class RotaryEmbedding(CustomOp):
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
-        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
 
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -121,75 +119,6 @@ class RotaryEmbedding(CustomOp):
                              self.cos_sin_cache, self.is_neox_style)
         return query, key
 
-    def forward_hip(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first=False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # currently only rotary embedding ops from AITER package are
-        # supported for HiP forward.
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
-                                               is_nope_first)
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_hip_rocm_aiter(
-        self,
-        positions: torch.Tensor,
-        # if is_nope_first
-        #  [[batch_size, seq_len, num_heads, nope_size+rope_size]
-        # if NOT is_nope_first
-        #  [[batch_size, seq_len, num_heads, rope_size+nope_size],
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first: bool = False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.cos_sin_cache.device != query.device or \
-            self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
-
-        cos = cos.unsqueeze(-2).unsqueeze(-2)
-        sin = sin.unsqueeze(-2).unsqueeze(-2)
-
-        rotate_style = 0 if self.is_neox_style else 1
-
-        num_tokens = positions.numel()
-
-        query_shape = query.shape
-        query = query.view(1, num_tokens, -1, self.head_size)
-        if key is not None:
-            key_shape = key.shape
-            key = key.view(1, num_tokens, -1, self.head_size)
-
-        positions = positions.view(*query.shape[:2])
-        if offsets is not None:
-            offsets = offsets.view(*query.shape[:2])
-
-        if not is_nope_first:
-            query_ = query[..., :self.rotary_dim]
-            key_ = key[..., :self.rotary_dim] if key is not None else None
-        else:
-            query_ = query[..., -self.rotary_dim:]
-            key_ = key[..., -self.rotary_dim:] if key is not None else None
-
-        if key_ is None:
-            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
-                positions, sin, cos, query_, offsets, rotate_style,
-                is_nope_first)
-            return query.view(query_shape), None
-
-        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
-            positions, sin, cos, query_, key_, offsets, rotate_style,
-            is_nope_first)
-
-        return query.view(query_shape), key.view(key_shape)
-
     def forward_xpu(
         self,
         positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 99b6bb212033..8d821bea19e3 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func
 
 
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+    return 0.1 * math.log(scale) + 1.0
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index 5af671703a3f..cd888b733426 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import math
 from typing import Optional
 
 import torch
@@ -9,7 +10,13 @@ from vllm.platforms import current_platform
 
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_get_mscale, yarn_linear_ramp_mask)
+                     yarn_linear_ramp_mask)
+
+
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
 
 
 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -89,9 +96,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
-
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
deleted file mode 100644
index 91a2318badb4..000000000000
--- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Optional
-
-import torch
-
-import vllm.envs as envs
-from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
-
-
-def is_rocm_rotary_embedding_enabled() -> bool:
-    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-if is_rocm_rotary_embedding_enabled():
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
-        mutates_args=["key", "query"],
-        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
-        mutates_args=["query"],
-        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
\ No newline at end of file
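
Reviewer note, not part of the patch: after this revert, common.py keeps only the
single-argument yarn_get_mscale, while deepseek_scaling_rope.py carries its own
two-argument variant. The sketch below just places the two function bodies from the
diff side by side for comparison; the "_common"/"_deepseek" suffixes and the sample
scale value are illustrative additions, not names or numbers from the patch.

    import math


    def yarn_get_mscale_common(scale: float = 1) -> float:
        # common.py version after the revert: no extra mscale knob.
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0


    def yarn_get_mscale_deepseek(scale: float = 1, mscale: float = 1) -> float:
        # deepseek_scaling_rope.py version: scales the log term by mscale.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0


    if __name__ == "__main__":
        scale = 40.0  # illustrative rope scaling factor
        print(yarn_get_mscale_common(scale))
        print(yarn_get_mscale_deepseek(scale, mscale=1.0))  # equal when mscale == 1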