mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 00:16:00 +08:00
943 lines
28 KiB
Python
943 lines
28 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import functools
|
|
from collections.abc import Callable
|
|
|
|
import torch
|
|
|
|
import vllm.envs as envs
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
|
|
|
|
|
def is_aiter_found() -> bool:
    """Report whether the ``aiter`` package can be located.

    Uses ``importlib.util.find_spec`` so the package is only looked up, not
    imported.

    Returns:
        True if a module spec for ``aiter`` exists, False otherwise.
    """
    from importlib.util import find_spec

    aiter_spec = find_spec("aiter")
    return aiter_spec is not None


# ``find_spec`` cannot be traced by torch.compile. Aiter availability may be
# queried inside torch-compiled forward passes, so the lookup is done once at
# import time and cached in this module-level constant to avoid graph breaks.
IS_AITER_FOUND: bool = is_aiter_found()
|
|
|
|
|
|
def if_aiter_supported(func: Callable) -> Callable:
    """Decorator that runs ``func`` only when ROCm AITER is usable.

    The wrapped function executes only if all of the following hold:
    the ``aiter`` package was found at import time (``IS_AITER_FOUND``),
    the current platform is ROCm, and the device is a gfx9 architecture
    (``vllm.platforms.rocm.on_gfx9``). Otherwise the wrapper returns ``None``
    without calling ``func``.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper with ``func``'s metadata that returns ``func``'s result when
        supported, or ``None`` when not.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Check the cheap, platform-independent conditions first so that
        # non-ROCm platforms never import the ROCm platform module.
        if not (IS_AITER_FOUND and current_platform.is_rocm()):
            # Not supported: do nothing and return None.
            return None

        from vllm.platforms.rocm import on_gfx9

        # Checks the device arch last; it is only meaningful on ROCm.
        if not on_gfx9():
            return None
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Eager implementation of the ``rocm_aiter_fused_moe`` custom op.

    The integer ``activation_method`` and ``quant_method`` codes are turned
    into aiter's ``ActivationType`` and ``QuantType`` enums. Every argument
    is then forwarded positionally to ``aiter.fused_moe.fused_moe``, and its
    result is returned.
    """
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    # The custom-op schema carries plain ints; aiter wants its own enums.
    act_enum = ActivationType(activation_method)
    quant_enum = QuantType(quant_method)

    positional_args = (
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        act_enum,
        quant_enum,
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
    return fused_moe(*positional_args)
|
|
|
|
|
|
def _rocm_aiter_fused_moe_fake(
|
|
hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
topk_weight: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
expert_mask: torch.Tensor | None = None,
|
|
activation_method: int = 0,
|
|
quant_method: int = 0,
|
|
doweight_stage1: bool = False,
|
|
w1_scale: torch.Tensor | None = None,
|
|
w2_scale: torch.Tensor | None = None,
|
|
a1_scale: torch.Tensor | None = None,
|
|
a2_scale: torch.Tensor | None = None,
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(hidden_states)
|
|
|
|
|
|
def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Eager implementation of the ``rocm_aiter_asm_moe_tkw1`` custom op.

    The integer ``activation_method`` is converted to aiter's
    ``ActivationType``. The call then goes to
    ``aiter.fused_moe_bf16_asm.asm_moe_tkw1`` with the five leading tensors
    positional and all optional settings as keywords.
    """
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    keyword_args = dict(
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=ActivationType(activation_method),
    )
    return asm_moe_tkw1(
        hidden_states, w1, w2, topk_weights, topk_ids, **keyword_args
    )
|
|
|
|
|
|
def _rocm_aiter_asm_moe_tkw1_fake(
|
|
hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
fc1_scale: torch.Tensor | None = None,
|
|
fc2_scale: torch.Tensor | None = None,
|
|
fc1_smooth_scale: torch.Tensor | None = None,
|
|
fc2_smooth_scale: torch.Tensor | None = None,
|
|
a16: bool = False,
|
|
per_tensor_quant_scale: torch.Tensor | None = None,
|
|
expert_mask: torch.Tensor | None = None,
|
|
activation_method: int = 0,
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(hidden_states)
|
|
|
|
|
|
def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    """Eager implementation of the ``rocm_aiter_topk_softmax`` custom op.

    Forwards all arguments positionally to ``aiter.topk_softmax`` and returns
    ``None``. ``topk_weights``, ``topk_indices`` and
    ``token_expert_indices`` are the arguments declared as mutated when this
    op is registered.
    """
    from aiter import topk_softmax

    call_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    topk_softmax(*call_args)
|
|
|
|
|
|
def _rocm_aiter_topk_softmax_fake(
|
|
topk_weights: torch.Tensor,
|
|
topk_indices: torch.Tensor,
|
|
token_expert_indices: torch.Tensor,
|
|
gating_output: torch.Tensor,
|
|
renormalize: bool,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Eager implementation of the ``rocm_aiter_biased_grouped_topk`` op.

    Passes every argument, in the same order, to ``aiter.biased_grouped_topk``
    and returns ``None``. ``topk_weights`` and ``topk_ids`` are the arguments
    declared as mutated when this op is registered.
    """
    from aiter import biased_grouped_topk

    forwarded = (
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )
    biased_grouped_topk(*forwarded)
|
|
|
|
|
|
def _rocm_aiter_biased_grouped_topk_fake(
|
|
gating_output: torch.Tensor,
|
|
correction_bias: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
num_expert_group: int,
|
|
topk_group: int,
|
|
need_renorm: bool,
|
|
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Eager implementation of the ``rocm_aiter_grouped_topk`` custom op.

    ``scoring_func`` is reduced to a boolean that is ``True`` only when it
    equals ``"softmax"``. That flag and the other arguments are forwarded
    positionally to ``aiter.grouped_topk``. Returns ``None``.
    """
    from aiter import grouped_topk

    # aiter takes a boolean "use softmax" flag rather than the scoring name.
    use_softmax = scoring_func == "softmax"
    grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        use_softmax,
        routed_scaling_factor,
    )
|
|
|
|
|
|
def _rocm_aiter_grouped_topk_fake(
|
|
gating_output: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
num_expert_group: int,
|
|
topk_group: int,
|
|
need_renorm: bool,
|
|
scoring_func: str = "softmax",
|
|
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Eager implementation of the ``rocm_aiter_mla_decode_fwd`` custom op.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before the call to
    ``aiter.mla.mla_decode_fwd``. Note that aiter's positional order differs
    from this function's: the three KV index tensors come before
    ``max_seqlen_qo``. ``o`` is the argument declared as mutated at
    registration; nothing is returned.
    """
    from aiter.mla import mla_decode_fwd

    kv_view = kv_buffer.view(-1, 1, 1, q.shape[-1])
    mla_decode_fwd(
        q,
        kv_view,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
|
|
|
|
|
def _rocm_aiter_mla_decode_fwd_fake(
|
|
q: torch.Tensor,
|
|
kv_buffer: torch.Tensor,
|
|
o: torch.Tensor,
|
|
qo_indptr: torch.Tensor,
|
|
max_seqlen_qo: int,
|
|
kv_indptr: torch.Tensor | None = None,
|
|
kv_indices: torch.Tensor | None = None,
|
|
kv_last_page_lens: torch.Tensor | None = None,
|
|
sm_scale: float = 1.0,
|
|
logit_cap: float = 0.0,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
def _rocm_aiter_gemm_w8a8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager implementation of the ``rocm_aiter_gemm_w8a8`` custom op.

    Delegates to ``aiter.gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)``.
    The CK kernel expects ``A`` as ``[M, K]`` and ``B`` as ``[N, K]``; the
    matching fake derives the output as ``[A.shape[0], B.shape[0]]``.
    NOTE(review): CutlassScaledMMLinearKernel prepares its weight ``w_q`` as
    ``[K, N]``, so callers coming from that path presumably transpose before
    calling; confirm at the call site.
    """
    from aiter import gemm_a8w8_CK

    result = gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
    return result
|
|
|
|
|
|
def _rocm_aiter_gemm_w8a8_fake(
|
|
A: torch.Tensor,
|
|
B: torch.Tensor,
|
|
As: torch.Tensor,
|
|
Bs: torch.Tensor,
|
|
bias: torch.Tensor | None = None,
|
|
output_dtype: torch.dtype = torch.float16,
|
|
) -> torch.Tensor:
|
|
m = A.shape[0]
|
|
n = B.shape[0]
|
|
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
|
return Y
|
|
|
|
|
|
def _rocm_aiter_gemm_w8a8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Eager implementation of the ``rocm_aiter_gemm_w8a8_blockscale`` op.

    Calls ``aiter.gemm_a8w8_blockscale`` with the four tensors positionally
    and ``output_dtype`` passed as the ``dtype`` keyword.
    """
    from aiter import gemm_a8w8_blockscale

    product = gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
    return product
|
|
|
|
|
|
def _rocm_aiter_gemm_w8a8_blockscale_fake(
|
|
A: torch.Tensor,
|
|
B: torch.Tensor,
|
|
As: torch.Tensor,
|
|
Bs: torch.Tensor,
|
|
output_dtype: torch.dtype = torch.float16,
|
|
) -> torch.Tensor:
|
|
m = A.shape[0]
|
|
n = B.shape[0]
|
|
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
|
return Y
|
|
|
|
|
|
def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Eager implementation of the ``rocm_aiter_rms_norm`` custom op.

    Inputs with at most two dimensions go straight to ``aiter.rms_norm``.
    Higher-rank inputs are flattened to ``(-1, x.shape[-1])``, normalized,
    and reshaped back to their original shape. (Presumably the aiter kernel
    only accepts 2-D input; confirm against aiter.)
    """
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normalized = rms_norm(flattened, weight, variance_epsilon)
    return normalized.reshape(original_shape)
|
|
|
|
|
|
def _rocm_aiter_rms_norm_fake(
|
|
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(x)
|
|
|
|
|
|
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation of the ``rocm_aiter_rmsnorm2d_fwd_with_add`` op.

    Allocates fresh output buffers shaped like ``residual`` and ``x``, hands
    them to ``aiter.rmsnorm2d_fwd_with_add`` to fill, and returns them as
    ``(normalized_output, residual_output)``. The inputs are not used as
    output buffers.
    """
    from aiter import rmsnorm2d_fwd_with_add

    res_out = torch.empty_like(residual)
    norm_out = torch.empty_like(x)
    # aiter argument order: output, input, residual_in, residual_out,
    # weight, epsilon.
    rmsnorm2d_fwd_with_add(
        norm_out,
        x,
        residual,
        res_out,
        weight,
        variance_epsilon,
    )
    return norm_out, res_out
|
|
|
|
|
|
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
|
x: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
variance_epsilon: float,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
return torch.empty_like(x), torch.empty_like(residual)
|
|
|
|
|
|
# Global flag to ensure ops are registered only once
|
|
_OPS_REGISTERED = False
|
|
|
|
|
|
class rocm_aiter_ops:
    """Feature flags and entry points for the ROCm AITER kernel library.

    The underscore-prefixed class attributes snapshot the corresponding
    ``VLLM_ROCM_USE_AITER*`` settings from ``vllm.envs`` when the class is
    created. Each ``is_*_enabled`` predicate is wrapped by
    ``if_aiter_supported``. When the platform is not ROCm, the device is not
    gfx9, or ``aiter`` is missing, the predicate returns ``None`` (falsy)
    instead of a bool.

    The remaining static methods are thin wrappers. Some dispatch to the
    ``torch.ops.vllm.rocm_aiter_*`` custom ops registered in
    ``register_ops_once``; others import and call aiter Triton kernels
    directly.
    """

    _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
    _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
    _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
    _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
    _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
    _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
    _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
    _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
    _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
    _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS

    # (n, k) shapes reported as tuned by is_triton_gemm_w8a8_tuned(). A
    # frozenset is built once and gives O(1) membership checks.
    _TRITON_GEMM_W8A8_TUNED_SHAPES: frozenset[tuple[int, int]] = frozenset(
        {
            (1024, 8192),
            (2112, 7168),
            (3072, 1536),
            (32768, 8192),
            (4096, 7168),
            (4608, 7168),
            (512, 7168),
            (7168, 2048),
            (7168, 256),
            (8192, 1024),
            (8192, 32768),
        }
    )

    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        """Verifies device specs and availability of aiter main env variable."""
        return cls._AITER_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enaled(cls) -> bool:
        """Verifies AITER linear is enabled and the platform uses FP8 FNUZ.

        The misspelled name is kept for existing callers; the correctly
        spelled ``is_linear_fp8_enabled`` is an alias of this method.
        """
        return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()

    # Correctly spelled alias of is_linear_fp8_enaled.
    is_linear_fp8_enabled = is_linear_fp8_enaled

    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._FMOE_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fusion_moe_shared_experts_enabled(cls) -> bool:
        """Fused MoE is enabled and shared-expert fusion is requested."""
        return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MLA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MHA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_pa_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fp8bmm_enabled(cls) -> bool:
        """AITER and its FP8 batched-matmul path are both enabled."""
        return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
        """AITER and its ASM FP4 GEMM path are both enabled."""
        return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM

    @classmethod
    @if_aiter_supported
    def is_triton_rotary_embed_enabled(cls) -> bool:
        """AITER and its Triton rotary-embedding path are both enabled."""
        return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

    @staticmethod
    @if_aiter_supported
    def register_ops_once() -> None:
        """Register every ``torch.ops.vllm.rocm_aiter_*`` custom op once.

        The module-level ``_OPS_REGISTERED`` flag is set after the first
        registration, so later calls return immediately.
        """
        global _OPS_REGISTERED
        if _OPS_REGISTERED:
            return

        # torch older than 2.7 gets the needs_fixed_stride_order tag; newer
        # versions get no extra tags. Only the MLA decode op uses these tags.
        tags = (
            ()
            if is_torch_equal_or_newer("2.7.0")
            else (torch.Tag.needs_fixed_stride_order,)
        )
        dispatch_key = current_platform.dispatch_key

        direct_register_custom_op(
            op_name="rocm_aiter_asm_moe_tkw1",
            op_func=_rocm_aiter_asm_moe_tkw1_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_fused_moe",
            op_func=_rocm_aiter_fused_moe_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_fused_moe_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_topk_softmax",
            op_func=_rocm_aiter_topk_softmax_impl,
            mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
            fake_impl=_rocm_aiter_topk_softmax_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_biased_grouped_topk",
            op_func=_rocm_aiter_biased_grouped_topk_impl,
            mutates_args=["topk_weights", "topk_ids"],
            fake_impl=_rocm_aiter_biased_grouped_topk_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_grouped_topk",
            op_func=_rocm_aiter_grouped_topk_impl,
            mutates_args=["topk_weights", "topk_ids"],
            fake_impl=_rocm_aiter_grouped_topk_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_mla_decode_fwd",
            op_func=_rocm_aiter_mla_decode_fwd_impl,
            mutates_args=["o"],
            fake_impl=_rocm_aiter_mla_decode_fwd_fake,
            tags=tags,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_gemm_w8a8",
            op_func=_rocm_aiter_gemm_w8a8_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_gemm_w8a8_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_gemm_w8a8_blockscale",
            op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rms_norm",
            op_func=_rocm_aiter_rms_norm_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_rms_norm_fake,
            dispatch_key=dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
            op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
            dispatch_key=dispatch_key,
        )

        _OPS_REGISTERED = True

    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused residual-add + RMSNorm; returns ``(output, residual_out)``."""
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
            x, residual, weight, variance_epsilon
        )

    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
    ) -> torch.Tensor:
        """RMSNorm via the ``rocm_aiter_rms_norm`` custom op."""
        return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)

    @staticmethod
    def gemm_w8a8(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Scaled 8-bit GEMM via the ``rocm_aiter_gemm_w8a8`` custom op."""
        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)

    @staticmethod
    def gemm_w8a8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Block-scaled 8-bit GEMM via the custom op.

        ``block_size`` is accepted for interface compatibility but is not
        forwarded to the op.
        """
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
            A, B, As, Bs, output_dtype
        )

    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weight: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
        quant_method: int = 0,
        doweight_stage1: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fused MoE via the ``rocm_aiter_fused_moe`` custom op."""
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation_method,
            quant_method,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )

    @staticmethod
    def asm_moe_tkw1(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        fc1_scale: torch.Tensor | None = None,
        fc2_scale: torch.Tensor | None = None,
        fc1_smooth_scale: torch.Tensor | None = None,
        fc2_smooth_scale: torch.Tensor | None = None,
        a16: bool = False,
        per_tensor_quant_scale: torch.Tensor | None = None,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
    ) -> torch.Tensor:
        """ASM MoE via the ``rocm_aiter_asm_moe_tkw1`` custom op."""
        return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale,
            fc2_scale,
            fc1_smooth_scale,
            fc2_smooth_scale,
            a16,
            per_tensor_quant_scale,
            expert_mask,
            activation_method,
        )

    @staticmethod
    def topk_softmax(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        """Run the top-k softmax op and return ``(topk_weights, topk_indices)``.

        The op is registered as mutating its first three arguments, so the
        returned tensors are the same objects that were passed in.
        """
        torch.ops.vllm.rocm_aiter_topk_softmax(
            topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
        )
        return topk_weights, topk_indices

    @staticmethod
    def biased_grouped_topk(
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Biased grouped top-k; writes into ``topk_weights`` / ``topk_ids``."""
        torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            routed_scaling_factor,
        )

    @staticmethod
    def grouped_topk(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Grouped top-k; writes into ``topk_weights`` / ``topk_ids``."""
        torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            scoring_func,
            routed_scaling_factor,
        )

    @staticmethod
    def mla_decode_fwd(
        q: torch.Tensor,
        kv_buffer: torch.Tensor,
        o: torch.Tensor,
        sm_scale: float,
        qo_indptr: torch.Tensor,
        max_seqlen_qo: int,
        kv_indptr: torch.Tensor | None = None,
        kv_indices: torch.Tensor | None = None,
        kv_last_page_lens: torch.Tensor | None = None,
        logit_cap: float = 0.0,
    ) -> None:
        """MLA decode via the custom op; the result is written into ``o``.

        ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])``. The op
        implementation applies the same view again, which is a no-op.
        """
        torch.ops.vllm.rocm_aiter_mla_decode_fwd(
            q,
            kv_buffer.view(-1, 1, 1, q.shape[-1]),
            o,
            qo_indptr,
            max_seqlen_qo,
            kv_indptr,
            kv_indices,
            kv_last_page_lens,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
        )

    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """FP4 x FP4 GEMM using aiter Triton kernels.

        If ``x_scales`` is None, ``x`` is quantized on the fly with
        ``dynamic_mxfp4_quant``. Otherwise ``x`` is treated as already
        quantized, with ``x_scales`` as its scales. The output is allocated
        as ``[x_q.shape[0], weight.shape[0]]`` of ``out_dtype``. The name is
        misspelled but kept for callers; ``triton_fp4_gemm_dynamic_quant`` is
        an alias.
        """
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
        from aiter.ops.triton.quant import dynamic_mxfp4_quant

        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q = x
            x_s = x_scales

        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y

    # Correctly spelled alias of triton_fp4_gemm_dynamic_qaunt.
    triton_fp4_gemm_dynamic_quant = triton_fp4_gemm_dynamic_qaunt

    @staticmethod
    def triton_rotary_embed(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        is_neox_style: bool,
    ) -> None:
        """Apply rotary embedding to ``query`` and ``key`` via aiter Triton.

        The cache is split along its last dim into ``cos`` and ``sin``
        halves. ``query`` and ``key`` are viewed as
        ``(num_tokens, -1, head_size)``, and their first ``rotary_dim``
        channels are passed as views to aiter's
        ``rope_cached_thd_positions_2c_fwd_inplace``. The kernel presumably
        updates those views, and therefore the caller's tensors, in place.
        Nothing is returned.
        """
        from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

        num_tokens = positions.numel()
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        # aiter rotate_style: 0 for NeoX-style rotation, 1 otherwise.
        rotate_style = 0 if is_neox_style else 1

        query = query.view(num_tokens, -1, head_size)
        key = key.view(num_tokens, -1, head_size)
        query_ = query[..., :rotary_dim]
        key_ = key[..., :rotary_dim]
        positions = positions.view(num_tokens)
        rope_cached_thd_positions_2c_fwd_inplace(
            positions,
            sin,
            cos,
            query_,
            key_,
            rotate_style,
            reuse_freqs_front_part=True,
            is_nope_first=False,
        )

    @staticmethod
    def triton_fp8_bmm(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        group_size: int = 128,
        bias: torch.Tensor | None = None,
        dtype: torch.dtype | None = torch.bfloat16,
        splitK: int | None = None,
        YQ: torch.Tensor | None = None,
        transpose_bm: bool | None = False,
        config: dict | None = None,
    ) -> torch.Tensor:
        """FP8 batched GEMM via aiter's Triton kernel.

        Uses the kernel with per-token-group activation prequant and
        per-batched-tensor weight quant; all arguments are forwarded.
        """
        # ruff: noqa: E501 # isort: skip
        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
        )

        return aiter_triton_fp8_bmm(
            X,
            WQ,
            w_scale,
            group_size=group_size,
            bias=bias,
            dtype=dtype,
            splitK=splitK,
            YQ=YQ,
            transpose_bm=transpose_bm,
            config=config,
        )

    @staticmethod
    def triton_gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Block-scaled 8-bit GEMM via aiter's Triton kernel.

        ``block_size`` is accepted for interface compatibility but is not
        forwarded to the kernel.
        """
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

    @staticmethod
    def per_1x128_fp8_quant(
        input_2d: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Quantize ``input_2d`` to fp8 with aiter's per-1x128 HIP quantizer.

        Only the fp8 output dtype is supported. The input is made contiguous
        before quantization.
        """
        from aiter import QuantType, dtypes, get_hip_quant

        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)

    @staticmethod
    def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
        """Whether ``(n, k)`` is in the set of tuned Triton w8a8 GEMM shapes."""
        return (n, k) in rocm_aiter_ops._TRITON_GEMM_W8A8_TUNED_SHAPES

    @staticmethod
    def shuffle_weight(
        tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> torch.Tensor:
        """Shuffle one tensor into aiter's block layout.

        Args:
            tensor: The tensor to rearrange.
            layout: Block sizes used for the shuffle. Default is (16, 16).

        Returns:
            The shuffled tensor returned by ``aiter.ops.shuffle.shuffle_weight``.
        """
        from aiter.ops.shuffle import shuffle_weight

        return shuffle_weight(tensor, layout=layout)

    @staticmethod
    def shuffle_weights(
        *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> tuple[torch.Tensor, ...]:
        """
        Applies shuffle_weight function from AITER to each
        input tensor and returns them.

        Rearranges (shuffles) the input tensor/s
        into a specified block layout for optimized computation.

        Args:
            *tensors: Variable number of torch.Tensor objects.
            layout: A pair of integers specifying the block sizes used to divide
                the tensors during shuffling. Default is (16, 16).

        Returns:
            A Tuple of shuffled tensors.
        """
        from aiter.ops.shuffle import shuffle_weight

        return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
|
|
|
|
|
|
# Register the AITER custom ops at import time when the aiter package is
# installed. register_ops_once() is itself guarded by if_aiter_supported and
# by the _OPS_REGISTERED flag, so repeated imports or calls are harmless.
if IS_AITER_FOUND:
    rocm_aiter_ops.register_ops_once()
|