diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 10fce857a8ae2..6dfc28be7da1a 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -8,6 +8,7 @@
 import torch
 
 from vllm.model_executor.custom_op import CustomOp
 
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
+from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
 
 
 @CustomOp.register("rotary_embedding")
@@ -35,6 +36,7 @@ class RotaryEmbedding(CustomOp):
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
 
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -119,6 +121,75 @@ class RotaryEmbedding(CustomOp):
                              self.cos_sin_cache, self.is_neox_style)
         return query, key
 
+    def forward_hip(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        is_nope_first: bool = False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Currently only the rotary embedding ops from the AITER package
+        # are supported for the HIP forward path.
+        if self.is_rocm_aiter_enabled:
+            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
+                                               is_nope_first)
+        return self.forward_native(positions, query, key, offsets)
+
+    def forward_hip_rocm_aiter(
+        self,
+        positions: torch.Tensor,
+        # if is_nope_first:
+        #     [batch_size, seq_len, num_heads, nope_size + rope_size]
+        # if NOT is_nope_first:
+        #     [batch_size, seq_len, num_heads, rope_size + nope_size]
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        is_nope_first: bool = False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
+
+        cos = cos.unsqueeze(-2).unsqueeze(-2)
+        sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+        rotate_style = 0 if self.is_neox_style else 1
+
+        num_tokens = positions.numel()
+
+        query_shape = query.shape
+        query = query.view(1, num_tokens, -1, self.head_size)
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(1, num_tokens, -1, self.head_size)
+
+        positions = positions.view(*query.shape[:2])
+        if offsets is not None:
+            offsets = offsets.view(*query.shape[:2])
+
+        if not is_nope_first:
+            query_ = query[..., :self.rotary_dim]
+            key_ = key[..., :self.rotary_dim] if key is not None else None
+        else:
+            query_ = query[..., -self.rotary_dim:]
+            key_ = key[..., -self.rotary_dim:] if key is not None else None
+
+        if key_ is None:
+            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
+                positions, sin, cos, query_, offsets, rotate_style,
+                is_nope_first)
+            return query.view(query_shape), None
+
+        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
+            positions, sin, cos, query_, key_, offsets, rotate_style,
+            is_nope_first)
+
+        return query.view(query_shape), key.view(key_shape)
+
     def forward_xpu(
         self,
         positions: torch.Tensor,
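The reshaping inside forward_hip_rocm_aiter above is easy to misread in diff form. The following is a minimal CPU-only sketch of the same tensor plumbing, with the in-place AITER kernel call elided since it requires a ROCm build of the aiter package; the shape values are arbitrary illustration, not vLLM defaults.

    import torch

    # Shapes mirror the method above: vLLM hands the layer flat
    # [num_tokens, num_heads * head_size] tensors.
    num_tokens, num_heads, head_size, rotary_dim = 4, 2, 8, 4
    query = torch.randn(num_tokens, num_heads * head_size)
    positions = torch.arange(num_tokens)

    query_shape = query.shape
    # AITER expects [batch, seq_len, heads, head_size]; batch folds into 1.
    query_4d = query.view(1, num_tokens, -1, head_size)
    positions = positions.view(*query_4d.shape[:2])

    # With is_nope_first=False the rotary half leads, so slice the front.
    # The slice is a view, so an in-place kernel would update query too.
    query_rot = query_4d[..., :rotary_dim]
    assert query_rot.data_ptr() == query_4d.data_ptr()

    # torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(...)
    out = query_4d.view(query_shape)  # restore the caller-visible layout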
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func
 
 
-def yarn_get_mscale(scale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index cd888b733426b..5af671703a3f4 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import math
 from typing import Optional
 
 import torch
@@ -10,13 +9,7 @@
 from vllm.platforms import current_platform
 
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_linear_ramp_mask)
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+                     yarn_get_mscale, yarn_linear_ramp_mask)
 
 
 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -96,6 +89,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
+        if self.is_rocm_aiter_enabled:
+            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
+
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
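The two hunks above consolidate yarn_get_mscale: common.py gains the mscale multiplier, which makes the copy previously duplicated in deepseek_scaling_rope.py redundant. A quick self-check of the consolidated helper; the function body is copied from the hunk and the sample values are arbitrary:

    import math

    def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    assert yarn_get_mscale(4.0) == 0.1 * math.log(4.0) + 1.0  # old one-argument formula
    assert yarn_get_mscale(0.5, mscale=2.0) == 1.0            # scale <= 1 short-circuits
    print(yarn_get_mscale(40.0, mscale=0.707))                # DeepSeek-style call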
diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
new file mode 100644
index 0000000000000..91a2318badb40
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+
+def is_rocm_rotary_embedding_enabled() -> bool:
+    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    import aiter as ops
+    if offsets is None:
+        ops.rope_cached_positions_fwd_inplace(
+            query,
+            cos,
+            sin,
+            positions,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+    else:
+        ops.rope_cached_positions_offsets_fwd_inplace(
+            query,
+            cos,
+            sin,
+            positions,
+            offsets,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    import aiter as ops
+    if offsets is None:
+        ops.rope_cached_positions_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+    else:
+        ops.rope_cached_positions_offsets_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            offsets,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    pass
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    pass
+
+
+if is_rocm_rotary_embedding_enabled():
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
+        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
+        mutates_args=["key", "query"],
+        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
+        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
+        mutates_args=["query"],
+        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
\ No newline at end of file
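The registrations above declare via mutates_args that the kernels write query (and key) in place and return None, which is why the fake implementations used for tracing can be empty no-ops. A hedged usage sketch, assuming a ROCm build with VLLM_ROCM_USE_AITER set so the ops are actually registered; apply_aiter_rope is an illustrative wrapper, not part of the patch:

    from typing import Optional

    import torch

    def apply_aiter_rope(positions: torch.Tensor, sin: torch.Tensor,
                         cos: torch.Tensor, query: torch.Tensor,
                         key: torch.Tensor,
                         offsets: Optional[torch.Tensor] = None) -> None:
        # rotate_style 0 selects the NeoX layout, 1 the GPT-J layout,
        # matching "rotate_style = 0 if self.is_neox_style else 1" above.
        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
            positions, sin, cos, query, key, offsets,
            0,      # rotate_style: NeoX-style rotation
            False,  # is_nope_first: rotary half leads
        )
        # query and key have been rotated in place; nothing is returned.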