Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Revert "[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module." (#22956)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
3d232dbd19
commit
b2f6c247a9
vllm/model_executor/layers/rotary_embedding/base.py

@@ -8,7 +8,6 @@ import torch
 from vllm.model_executor.custom_op import CustomOp
 
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
 
 
 @CustomOp.register("rotary_embedding")
@@ -36,7 +35,6 @@ class RotaryEmbedding(CustomOp):
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
-        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
 
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
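Note: `RotaryEmbedding` is a `CustomOp` (vllm/model_executor/custom_op.py), so dropping the `forward_hip` override in the next hunk simply sends ROCm back through the generic dispatch. A minimal sketch of that per-platform dispatch idea, with illustrative names only, not vLLM's actual implementation:

    import torch

    class DispatchingOp:
        # Sketch: pick a platform-specific forward_* method, falling back
        # to the portable reference implementation when no override exists.
        def forward(self, *args, **kwargs):
            impl = self.forward_hip if torch.version.hip else self.forward_native
            return impl(*args, **kwargs)

        def forward_native(self, x):
            return x  # portable reference path

        def forward_hip(self, x):
            # after this revert there is no HIP-specific rope path, so a
            # real CustomOp would route HIP through the default path instead
            return self.forward_native(x)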
@@ -121,75 +119,6 @@ class RotaryEmbedding(CustomOp):
                              self.cos_sin_cache, self.is_neox_style)
         return query, key
 
-    def forward_hip(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first=False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Currently only rotary embedding ops from the AITER package are
-        # supported for the HIP forward path.
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
-                                               is_nope_first)
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_hip_rocm_aiter(
-        self,
-        positions: torch.Tensor,
-        # if is_nope_first:
-        #   [batch_size, seq_len, num_heads, nope_size + rope_size]
-        # if NOT is_nope_first:
-        #   [batch_size, seq_len, num_heads, rope_size + nope_size]
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-        is_nope_first: bool = False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
-
-        cos = cos.unsqueeze(-2).unsqueeze(-2)
-        sin = sin.unsqueeze(-2).unsqueeze(-2)
-
-        rotate_style = 0 if self.is_neox_style else 1
-
-        num_tokens = positions.numel()
-
-        query_shape = query.shape
-        query = query.view(1, num_tokens, -1, self.head_size)
-        if key is not None:
-            key_shape = key.shape
-            key = key.view(1, num_tokens, -1, self.head_size)
-
-        positions = positions.view(*query.shape[:2])
-        if offsets is not None:
-            offsets = offsets.view(*query.shape[:2])
-
-        if not is_nope_first:
-            query_ = query[..., :self.rotary_dim]
-            key_ = key[..., :self.rotary_dim] if key is not None else None
-        else:
-            query_ = query[..., -self.rotary_dim:]
-            key_ = key[..., -self.rotary_dim:] if key is not None else None
-
-        if key_ is None:
-            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
-                positions, sin, cos, query_, offsets, rotate_style,
-                is_nope_first)
-            return query.view(query_shape), None
-
-        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
-            positions, sin, cos, query_, key_, offsets, rotate_style,
-            is_nope_first)
-
-        return query.view(query_shape), key.view(key_shape)
-
     def forward_xpu(
         self,
         positions: torch.Tensor,
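Note: most of the removed `forward_hip_rocm_aiter` is layout bookkeeping around the in-place AITER kernels. A standalone sketch of that reshape round trip (made-up sizes, not vLLM code):

    import torch

    num_tokens, num_heads, head_size, rotary_dim = 4, 8, 128, 64
    query = torch.randn(num_tokens, num_heads * head_size)
    query_shape = query.shape

    # The AITER kernels expect [batch, tokens, heads, head_size];
    # vLLM stores query with the head dimensions flattened.
    q = query.view(1, num_tokens, -1, head_size)
    q_rot = q[..., :rotary_dim]  # rope slice when is_nope_first is False
    # ... the in-place torch.ops.vllm.rocm_aiter_* op rotated q_rot here ...
    assert q.view(query_shape).shape == query_shape  # restore caller's layout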
vllm/model_executor/layers/rotary_embedding/common.py

@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func
 
 
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+    return 0.1 * math.log(scale) + 1.0
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import math
 from typing import Optional
 
 import torch
@@ -9,7 +10,13 @@ from vllm.platforms import current_platform
 
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_get_mscale, yarn_linear_ramp_mask)
+                     yarn_linear_ramp_mask)
 
 
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
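Note: after this revert, `common.yarn_get_mscale` loses its `mscale` argument while the DeepSeek-local copy above keeps it; the two agree whenever `mscale == 1`. A quick self-contained check (not vLLM code):

    import math

    def mscale_common(scale: float = 1) -> float:
        # post-revert version in common.py
        return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

    def mscale_deepseek(scale: float = 1, mscale: float = 1) -> float:
        # local version re-added to deepseek_scaling_rope.py
        return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0

    for s in (0.5, 1.0, 4.0, 40.0):
        assert mscale_common(s) == mscale_deepseek(s, mscale=1)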
@@ -89,9 +96,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        if self.is_rocm_aiter_enabled:
-            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
-
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py (deleted)

@@ -1,127 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Optional
-
-import torch
-
-import vllm.envs as envs
-from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
-
-
-def is_rocm_rotary_embedding_enabled() -> bool:
-    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_fwd_inplace(
-            query,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    import aiter as ops
-    if offsets is None:
-        ops.rope_cached_positions_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-    else:
-        ops.rope_cached_positions_offsets_2c_fwd_inplace(
-            query,
-            key,
-            cos,
-            sin,
-            positions,
-            offsets,
-            rotate_style,
-            reuse_freqs_front_part=True,
-            nope_first=is_nope_first,
-        )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
-    positions: torch.Tensor,
-    sin: torch.Tensor,
-    cos: torch.Tensor,
-    query: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    rotate_style: int = 0,
-    is_nope_first: bool = False,
-) -> None:
-    pass
-
-
-if is_rocm_rotary_embedding_enabled():
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
-        mutates_args=["key", "query"],
-        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
-        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
-        mutates_args=["query"],
-        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
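Note: the deleted module wired the AITER kernels into PyTorch's dispatcher via vLLM's `direct_register_custom_op` helper, pairing each in-place impl with a no-op fake impl so tracing (e.g. under torch.compile) sees correct metadata. The same pattern expressed with stock `torch.library` on a hypothetical op; nothing below is vLLM or AITER API:

    import torch
    from torch.library import custom_op

    @custom_op("demo::rotate_half_inplace", mutates_args=("query",))
    def rotate_half_inplace(query: torch.Tensor) -> None:
        # stand-in for an in-place rope kernel (neox-style rotate-half)
        half = query.shape[-1] // 2
        q1, q2 = query[..., :half].clone(), query[..., half:].clone()
        query[..., :half], query[..., half:] = -q2, q1

    @rotate_half_inplace.register_fake
    def _(query: torch.Tensor) -> None:
        # fake impl: the op mutates in place and returns nothing,
        # so there is no output shape to compute
        pass

    q = torch.randn(2, 8)
    torch.ops.demo.rotate_half_inplace(q)  # dispatch via the torch.ops namespace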