# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
from collections.abc import Callable

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

_FP8_DTYPE = current_platform.fp8_dtype()


def is_aiter_found() -> bool:
    from importlib.util import find_spec

    return find_spec("aiter") is not None


# `find_spec` is not torch.compile compatible. Since aiter availability may be
# checked inside forward passes that are torch.compiled, we keep this
# module-level constant so the check does not cause torch.compile graph breaks.
IS_AITER_FOUND = is_aiter_found()


def is_aiter_found_and_supported() -> bool:
    if current_platform.is_rocm() and IS_AITER_FOUND:
        from vllm.platforms.rocm import on_gfx9

        return on_gfx9()
    return False


def if_aiter_supported(func: Callable) -> Callable:
    """Decorator that executes the wrapped function only if the ROCm AITER
    package is installed and supported (gfx9 architectures); otherwise the
    wrapper returns None.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Checks the platform, device arch, and aiter library existence.

        if is_aiter_found_and_supported():
            return func(*args, **kwargs)

        return None

    return wrapper
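

# Illustrative usage of @if_aiter_supported (hedged sketch; not part of the
# original module). `_example_aiter_version` is a hypothetical helper added
# only for documentation purposes and is never called by vLLM.
@if_aiter_supported
def _example_aiter_version() -> str | None:
    # Only reached on ROCm + gfx9 with the aiter package installed; on any
    # other platform the wrapper above returns None without importing aiter.
    import aiter

    return getattr(aiter, "__version__", None)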


# `dtypes.fp8` can't be used directly inside an op because it returns the wrong
# result on gfx942, so resolving it once here is a workaround to get the
# correct FP8 dtype. This may be because get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
    from aiter import dtypes

    AITER_FP8_DTYPE = dtypes.fp8


def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    activation = ActivationType(activation_method)
    quant_type = QuantType(quant_method)

    return fused_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        activation,
        quant_type,
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def _rocm_aiter_fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
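

# A hedged illustration (not part of the original module) of the impl/fake
# pattern used throughout this file: every AITER kernel gets a real `_impl`
# that imports aiter lazily, plus a `_fake` that only describes the output
# metadata so torch.compile / fake tensors can trace the op without aiter.
# `_example_noop` is a hypothetical op name used purely for demonstration.
def _example_noop_impl(x: torch.Tensor) -> torch.Tensor:
    # Real implementation: does the actual work (here just a copy).
    return x.clone()


def _example_noop_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake implementation: performs no computation and only reports the
    # output shape/dtype/device that the real kernel would produce.
    return torch.empty_like(x)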


def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    activation = ActivationType(activation_method)

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=activation,
    )


def _rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    from aiter import topk_softmax

    topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
    )


def _rocm_aiter_topk_softmax_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    pass
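

# Hedged illustration (not part of the original module): topk_softmax writes
# its results in place, which is why the op is registered further below with
# mutates_args=["topk_weights", "topk_indices", "token_expert_indices"].
# A hypothetical call for `num_tokens` tokens and a top-k of `k` experts:
#
#     topk_weights = torch.empty(num_tokens, k, dtype=torch.float32, device="cuda")
#     topk_ids = torch.empty(num_tokens, k, dtype=torch.int32, device="cuda")
#     token_expert_indices = torch.empty_like(topk_ids)
#     rocm_aiter_ops.topk_softmax(
#         topk_weights, topk_ids, token_expert_indices, gating_output, renormalize=True
#     )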


def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # multiplied into topk_weights
) -> None:
    from aiter import biased_grouped_topk

    biased_grouped_topk(
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )


def _rocm_aiter_biased_grouped_topk_fake(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # multiplied into topk_weights
) -> None:
    pass


def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # multiplied into topk_weights
) -> None:
    is_softmax = scoring_func == "softmax"
    from aiter import grouped_topk

    grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        is_softmax,
        routed_scaling_factor,
    )


def _rocm_aiter_grouped_topk_fake(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # multiplied into topk_weights
) -> None:
    pass


# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None


def _check_aiter_mla_fp8_support() -> bool:
    """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
    global _AITER_MLA_SUPPORTS_FP8
    if _AITER_MLA_SUPPORTS_FP8 is None:
        try:
            import inspect

            from aiter.mla import mla_decode_fwd

            sig = inspect.signature(mla_decode_fwd)
            _AITER_MLA_SUPPORTS_FP8 = (
                "q_scale" in sig.parameters and "kv_scale" in sig.parameters
            )
        except Exception:
            _AITER_MLA_SUPPORTS_FP8 = False
    return _AITER_MLA_SUPPORTS_FP8


def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
    q_scale: torch.Tensor | None = None,
    kv_scale: torch.Tensor | None = None,
) -> None:
    from aiter.mla import mla_decode_fwd

    kwargs = {
        "sm_scale": sm_scale,
        "logit_cap": logit_cap,
    }

    # Only pass q_scale and kv_scale if the aiter library supports them
    if _check_aiter_mla_fp8_support():
        kwargs["q_scale"] = q_scale
        kwargs["kv_scale"] = kv_scale

    mla_decode_fwd(
        q,
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        **kwargs,
    )


def _rocm_aiter_mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
    q_scale: torch.Tensor | None = None,
    kv_scale: torch.Tensor | None = None,
) -> None:
    pass


def _rocm_aiter_gemm_a8w8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    from aiter import gemm_a8w8_CK

    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
    # a to be [M, K]
    # b to be [N, K]
    # CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)


def _rocm_aiter_gemm_a8w8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y
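

# Hedged shape sketch (not part of the original module): with A of shape
# [M, K], B of shape [N, K], a per-row activation scale As and a per-channel
# weight scale Bs, the op above produces Y of shape [M, N] in `output_dtype`.
# A hypothetical smoke test on a ROCm device:
#
#     M, N, K = 4, 8, 16
#     A = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
#     B = torch.randint(-8, 8, (N, K), dtype=torch.int8, device="cuda")
#     As = torch.rand(M, 1, device="cuda")
#     Bs = torch.rand(N, 1, device="cuda")
#     Y = torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs)  # -> shape (M, N)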


def _rocm_aiter_gemm_a8w8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    from aiter import gemm_a8w8_blockscale

    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


def _rocm_aiter_gemm_a8w8_blockscale_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    from aiter import rms_norm

    if x.dim() > 2:
        x_original_shape = x.shape
        x = x.reshape(-1, x_original_shape[-1])
        x = rms_norm(x, weight, variance_epsilon)
        return x.reshape(x_original_shape)

    return rms_norm(x, weight, variance_epsilon)


def _rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    return torch.empty_like(x)


def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter import rmsnorm2d_fwd_with_add

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out


def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x), torch.empty_like(residual)


def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import per_tensor_quant_hip

    return per_tensor_quant_hip(x, scale, quant_dtype)


def _rocm_aiter_per_tensor_quant_fake(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
        1, dtype=torch.float32, device=x.device
    )


def _rocm_aiter_per_token_quant_impl(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import dynamic_per_token_scaled_quant

    assert quant_dtype in [torch.int8, _FP8_DTYPE]

    out_shape = x.shape
    out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
    if scale is None:
        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
    dynamic_per_token_scaled_quant(
        out,
        x,
        scale,
        scale_ub=None,
        shuffle_scale=False,
        num_rows=None,
        num_rows_factor=1,
    )
    return out, scale


def _rocm_aiter_per_token_quant_fake(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    out_shape = x.shape
    return (
        torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
    )
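

# Hedged shape note (not part of the original module): for a 2D activation
# x of shape [num_tokens, hidden], the per-token quant op above returns the
# quantized tensor with the same shape plus one float32 scale per token:
#
#     x_q, x_s = rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE)
#     # x_q.shape == (num_tokens, hidden), x_s.shape == (num_tokens, 1)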


def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=residual,
    )
    return (x_quant, x_quant_scales, res)


def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    M, N = x.shape
    scale_shape = (M, (N + group_size - 1) // group_size)
    return (
        torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
        torch.empty_like(residual, device=residual.device),
    )


def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=None,
    )
    return (x_quant, x_quant_scales)


def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = x.shape
    scale_shape = (M, (N + group_size - 1) // group_size)
    return (
        torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
    )


def _rocm_aiter_group_fp8_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
    from aiter import QuantType, get_hip_quant

    aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
    return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)


def _rocm_aiter_group_fp8_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = x.shape
    x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
    out_bs = torch.empty(
        (
            M,
            (N + group_size - 1) // group_size,
        ),
        dtype=torch.float32,
        device=x.device,
    )
    return x_fp8, out_bs


def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

    return act_mul_and_fp8_group_quant(
        x,
        activation="silu",
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
    )


def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = x.shape
    assert N % 2 == 0
    N_half = N // 2
    x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
    out_bs = torch.empty(
        (
            M,
            (N_half + group_size - 1) // group_size,
        ),
        dtype=torch.float32,
        device=x.device,
    )
    return x_fp8, out_bs
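

# Hedged shape note (not part of the original module): the act-mul group quant
# above fuses SiLU-and-multiply with group FP8 quantization, so a gate/up input
# of shape [M, N] produces an FP8 output of shape [M, N // 2] and float32 group
# scales of shape [M, ceil((N // 2) / group_size)], matching the fake impl.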


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False


class rocm_aiter_ops:
    """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.

    This class centralizes the import and registration of AITER ops,
    and provides a unified interface for checking if AITER is enabled.
    Operations are only available on supported gfx9
    architectures when aiter is installed.

    The class uses environment variables to control which features are enabled,
    allowing fine-grained control over which AITER optimizations are used.

    Environment Variables:
        VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
        VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
        VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
        VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
        VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
        VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
        VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
        VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
        VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
        VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
        VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
        VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.

    Note:
        The environment variables are read once when this module is imported,
        so changing them afterwards has no effect. This is a deliberate
        performance choice: accessing environment variables is expensive, as
        described in https://github.com/vllm-project/vllm/issues/17067, so we
        avoid doing it repeatedly, especially in the hot path (the forward
        pass). Call refresh_env_variables() to reload the flags, for example
        after monkey patching the env variables in a unit test.

    Check Functions:
        All check functions (is_*_enabled) are decorated with @if_aiter_supported,
        which verifies: (1) the platform is ROCm, (2) the device arch is gfx9, and
        (3) the aiter library is installed. The check function then also verifies
        that the corresponding environment variable is enabled.
        i.e.                                            ___
        is_enabled() == current_platform.is_rocm() and     | checked by
                        current_platform.is_on_gfx9() and  | @if_aiter_supported
                        IS_AITER_FOUND and  _______________|
                        cls._AITER_ENABLED  -----> checked by the logic in `is_enabled()`

    Example:
        from vllm._aiter_ops import rocm_aiter_ops

        # Check if aiter is enabled before using operations
        if rocm_aiter_ops.is_enabled():
            result = rocm_aiter_ops.rms_norm(x, weight, epsilon)

    Operations:
        - RMS normalization: rms_norm, rms_norm2d_with_add
        - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
        - Fused MoE: fused_moe, asm_moe_tkw1
        - Routing: topk_softmax, biased_grouped_topk, grouped_topk
        - MLA decode: mla_decode_fwd
        - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
        - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
    """

    # Check if the env variable is set
    _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
    _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
    _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
    _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
    _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
    _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
    _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    # TODO: Consolidate under _LINEAR_ENABLED
    _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
    # TODO: Consolidate under _LINEAR_ENABLED
    _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
    # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
    _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
    _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
    # TODO: Consolidate under _LINEAR_ENABLED
    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

    @classmethod
    def refresh_env_variables(cls):
        """Reload all environment-variable-driven flags.

        The environment variables are read once when the module is imported,
        so call this helper to re-read them, for example after monkey patching
        the env variables in a unit test.
        """
        cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
        cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
        cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
        cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
        cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
        cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
        cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
        cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
        cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
        cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
        cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
        cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
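
    # Hedged illustration (not part of the original class): a typical unit-test
    # pattern for the refresh hook above, assuming pytest's `monkeypatch`
    # fixture and patching the module-level `envs` shim:
    #
    #     monkeypatch.setattr(envs, "VLLM_ROCM_USE_AITER", True)
    #     rocm_aiter_ops.refresh_env_variables()
    #     assert rocm_aiter_ops._AITER_ENABLED is True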

    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        return cls._AITER_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enabled(cls) -> bool:
        return cls.is_linear_enabled()

    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._FMOE_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fusion_moe_shared_experts_enabled(cls) -> bool:
        return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._MLA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._MHA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fp8bmm_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM

    @classmethod
    @if_aiter_supported
    def is_triton_rotary_embed_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

    @classmethod
    @if_aiter_supported
    def is_triton_gemm_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

    @staticmethod
    @if_aiter_supported
    def register_ops_once() -> None:
        global _OPS_REGISTERED
        if not _OPS_REGISTERED:
            tags = (
                tuple()
                if is_torch_equal_or_newer("2.7.0")
                else (torch.Tag.needs_fixed_stride_order,)
            )

            # register all the custom ops here
            direct_register_custom_op(
                op_name="rocm_aiter_asm_moe_tkw1",
                op_func=_rocm_aiter_asm_moe_tkw1_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_fused_moe",
                op_func=_rocm_aiter_fused_moe_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_fused_moe_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_topk_softmax",
                op_func=_rocm_aiter_topk_softmax_impl,
                mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
                fake_impl=_rocm_aiter_topk_softmax_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_biased_grouped_topk",
                op_func=_rocm_aiter_biased_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_biased_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_grouped_topk",
                op_func=_rocm_aiter_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_mla_decode_fwd",
                op_func=_rocm_aiter_mla_decode_fwd_impl,
                mutates_args=["o"],
                fake_impl=_rocm_aiter_mla_decode_fwd_fake,
                tags=tags,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_a8w8",
                op_func=_rocm_aiter_gemm_a8w8_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_gemm_a8w8_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_a8w8_blockscale",
                op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
                fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rms_norm",
                op_func=_rocm_aiter_rms_norm_impl,
                fake_impl=_rocm_aiter_rms_norm_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
                op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
                fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rmsnorm_fp8_group_quant",
                op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
                fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
                op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
                fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_act_mul_and_fp8_group_quant",
                op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
                fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_group_fp8_quant",
                op_func=_rocm_aiter_group_fp8_quant_impl,
                fake_impl=_rocm_aiter_group_fp8_quant_fake,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_per_tensor_quant",
                op_func=_rocm_aiter_per_tensor_quant_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_per_tensor_quant_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_per_token_quant",
                op_func=_rocm_aiter_per_token_quant_impl,
                mutates_args=["scale"],
                fake_impl=_rocm_aiter_per_token_quant_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            _OPS_REGISTERED = True
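
    # Hedged note (not part of the original class): once register_ops_once()
    # has run, each op is reachable as torch.ops.vllm.<op_name>, which is what
    # the thin static wrappers below call into, e.g.:
    #
    #     torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)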

    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
            x, residual, weight, variance_epsilon
        )

    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)

    @staticmethod
    def gemm_a8w8(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)

    @staticmethod
    def gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
            A, B, As, Bs, output_dtype
        )

    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weight: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
        quant_method: int = 0,
        doweight_stage1: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation_method,
            quant_method,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )

    @staticmethod
    def asm_moe_tkw1(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        fc1_scale: torch.Tensor | None = None,
        fc2_scale: torch.Tensor | None = None,
        fc1_smooth_scale: torch.Tensor | None = None,
        fc2_smooth_scale: torch.Tensor | None = None,
        a16: bool = False,
        per_tensor_quant_scale: torch.Tensor | None = None,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale,
            fc2_scale,
            fc1_smooth_scale,
            fc2_smooth_scale,
            a16,
            per_tensor_quant_scale,
            expert_mask,
            activation_method,
        )

    @staticmethod
    def topk_softmax(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        torch.ops.vllm.rocm_aiter_topk_softmax(
            topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
        )
        return topk_weights, topk_indices

    @staticmethod
    def biased_grouped_topk(
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            routed_scaling_factor,
        )

    @staticmethod
    def grouped_topk(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
    ) -> None:
        torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            scoring_func,
            routed_scaling_factor,
        )

    @staticmethod
    def mla_decode_fwd(
        q: torch.Tensor,
        kv_buffer: torch.Tensor,
        o: torch.Tensor,
        sm_scale: float,
        qo_indptr: torch.Tensor,
        max_seqlen_qo: int,
        kv_indptr: torch.Tensor | None = None,
        kv_indices: torch.Tensor | None = None,
        kv_last_page_lens: torch.Tensor | None = None,
        logit_cap: float = 0.0,
        q_scale: torch.Tensor | None = None,
        kv_scale: torch.Tensor | None = None,
    ):
        torch.ops.vllm.rocm_aiter_mla_decode_fwd(
            q,
            kv_buffer.view(-1, 1, 1, q.shape[-1]),
            o,
            qo_indptr,
            max_seqlen_qo,
            kv_indptr,
            kv_indices,
            kv_last_page_lens,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
            q_scale=q_scale,
            kv_scale=kv_scale,
        )

    @staticmethod
    def per_tensor_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)

    @staticmethod
    def per_token_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)

    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
        from aiter.ops.triton.quant import dynamic_mxfp4_quant

        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q = x
            x_s = x_scales

        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y

    @staticmethod
    def triton_rotary_embed(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        is_neox_style: bool,
    ):
        from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

        num_tokens = positions.numel()
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        rotate_style = 0 if is_neox_style else 1

        query = query.view(num_tokens, -1, head_size)
        key = key.view(num_tokens, -1, head_size)
        query_ = query[..., :rotary_dim]
        key_ = key[..., :rotary_dim]
        positions = positions.view(*query.shape[:1])
        rope_cached_thd_positions_2c_fwd_inplace(
            positions,
            sin,
            cos,
            query_,
            key_,
            rotate_style,
            reuse_freqs_front_part=True,
            is_nope_first=False,
        )
        query = query.view(query_shape)
        key = key.view(key_shape)

    @staticmethod
    def triton_fp8_bmm(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        group_size: int = 128,
        bias: torch.Tensor | None = None,
        dtype: torch.dtype | None = torch.bfloat16,
        splitK: int | None = None,
        YQ: torch.Tensor | None = None,
        transpose_bm: bool | None = False,
        config: dict | None = None,
    ) -> torch.Tensor:
        # ruff: noqa: E501 # isort: skip
        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
        )

        return aiter_triton_fp8_bmm(
            X,
            WQ,
            w_scale,
            group_size=group_size,
            bias=bias,
            dtype=dtype,
            splitK=splitK,
            YQ=YQ,
            transpose_bm=transpose_bm,
            config=config,
        )

    @staticmethod
    def triton_gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

    @staticmethod
    def group_fp8_quant(
        input_2d: torch.Tensor,
        group_size: int = 128,
    ) -> tuple[torch.Tensor, ...]:
        assert group_size == 128, "Group size must be 128"
        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)

    @staticmethod
    def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
        return (n, k) in [
            (1024, 8192),
            (2112, 7168),
            (3072, 1536),
            (32768, 8192),
            (4096, 7168),
            (4608, 7168),
            (512, 7168),
            (7168, 2048),
            (7168, 256),
            (8192, 1024),
            (8192, 32768),
        ]

    @staticmethod
    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
        return (n, k) in [
            (8192, 4096),
            (1280, 8192),
            (16384, 53248),
            (106496, 16384),
            (57344, 8192),
            (8192, 2048),
            (2560, 8192),
            (10240, 8192),
            (16384, 16384),
            (8192, 28672),
            (28672, 8192),
            (18432, 16384),
            (8192, 1024),
            (7168, 8192),
            (5120, 8192),
            (8192, 8192),
            (8192, 7168),
            (14336, 8192),
            (8192, 14336),
            (8192, 3584),
        ]

    @staticmethod
    def shuffle_weight(
        tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> torch.Tensor:
        from aiter.ops.shuffle import shuffle_weight

        return shuffle_weight(tensor, layout=layout)

    @staticmethod
    def shuffle_weights(
        *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> tuple[torch.Tensor, ...]:
        """
        Applies the shuffle_weight function from AITER to each
        input tensor and returns them.

        Rearranges (shuffles) the input tensor(s)
        into a specified block layout for optimized computation.

        Args:
            *tensors: Variable number of torch.Tensor objects.
            layout: A pair of integers specifying the block sizes used to divide
                the tensors during shuffling. Default is (16, 16).

        Returns:
            A tuple of shuffled tensors.
        """
        from aiter.ops.shuffle import shuffle_weight

        return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


rocm_aiter_ops.register_ops_once()
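

# --- Illustrative usage (hedged sketch; not part of the original module) -----
# How a caller is expected to gate on the environment-driven checks before
# dispatching to a registered AITER op. `_example_rms_norm_dispatch` is a
# hypothetical helper added only for documentation; vLLM never calls it.
def _example_rms_norm_dispatch(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    if rocm_aiter_ops.is_rmsnorm_enabled():
        # Dispatches to aiter's rms_norm through the custom op registered above.
        return rocm_aiter_ops.rms_norm(x, weight, eps)
    # Plain-PyTorch reference fallback (same math, no aiter dependency).
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight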