[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vllmellm 2025-11-10 17:20:53 +01:00 committed by GitHub
parent 40e2eeeb92
commit f080a83511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1193 additions and 924 deletions

View File

@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |

View File

@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
"""
import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
# Force reload aiter_ops to pick up the new environment variables.
if "rocm_aiter_ops" in sys.modules:
importlib.reload(rocm_aiter_ops)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")

View File

@ -4,6 +4,7 @@
import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_topk_func,
vllm_topk_softmax,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import (
RMSNorm,
dispatch_rocm_rmsnorm_func,
@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()
is_rocm_aiter_moe_enabled.cache_clear()
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax,
)
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_func(use_rocm_aiter)
assert topk_func == rocm_aiter_topk_softmax
if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax
else:
assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(
add_residual: bool,
dtype: torch.dtype,
use_rocm_aiter: str,
use_rocm_aiter_norm: str,
monkeypatch,
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
should_use_rocm_aiter = (
current_platform.is_rocm()
and int(use_rocm_aiter)
and int(use_rocm_aiter_norm)
and use_rocm_aiter
and dtype in RMS_NORM_SUPPORTED_DTYPES
)
if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:

941
vllm/_aiter_ops.py Normal file
View File

@ -0,0 +1,941 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def is_aiter_found() -> bool:
from importlib.util import find_spec
return find_spec("aiter") is not None
# `find_spec` is not torch.compile compatible.
# In cases where aiter availability might have
# been checked in forward passes that are torch compiled.
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
ROCm AITER package is supported on gfx9 archs.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existance.
from vllm.platforms.rocm import on_gfx9
if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
return func(*args, **kwargs)
else:
# Return None or do nothing if not supported
return None
return wrapper
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def _rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def _rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def _rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def _rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
is_softmax = scoring_func == "softmax"
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
is_softmax,
routed_scaling_factor,
)
def _rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def _rocm_aiter_mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
def _rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def _rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_gemm_w8a8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
def _rocm_aiter_gemm_w8a8_blockscale_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from aiter import rms_norm
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rms_norm(x, weight, variance_epsilon)
def _rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)
return output, residual_out
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
class rocm_aiter_ops:
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED
@classmethod
@if_aiter_supported
def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod
@if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod
@if_aiter_supported
def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_fp8bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
@classmethod
@if_aiter_supported
def is_triton_rotary_embed_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
tags = (
tuple()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=_rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=_rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=_rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=_rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=_rocm_aiter_mla_decode_fwd_fake,
tags=tags,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=_rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
def rms_norm2d_with_add(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
x, residual, weight, variance_epsilon
)
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
@staticmethod
def gemm_w8a8(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
@staticmethod
def gemm_w8a8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
A, B, As, Bs, output_dtype
)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation_method,
quant_method,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
@staticmethod
def asm_moe_tkw1(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale,
fc2_scale,
fc1_smooth_scale,
fc2_smooth_scale,
a16,
per_tensor_quant_scale,
expert_mask,
activation_method,
)
@staticmethod
def topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
@staticmethod
def biased_grouped_topk(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
@staticmethod
def grouped_topk(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = torch.bfloat16,
x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(
x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
)
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
@staticmethod
def triton_rotary_embed(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
rope_cached_thd_positions_2c_fwd_inplace(
positions,
sin,
cos,
query_,
key_,
rotate_style,
reuse_freqs_front_part=True,
is_nope_first=False,
)
query = query.view(query_shape)
key = key.view(key_shape)
@staticmethod
def triton_fp8_bmm(
X: torch.Tensor,
WQ: torch.Tensor,
w_scale: torch.Tensor,
group_size: int = 128,
bias: torch.Tensor | None = None,
dtype: torch.dtype | None = torch.bfloat16,
splitK: int | None = None,
YQ: torch.Tensor | None = None,
transpose_bm: bool | None = False,
config: dict | None = None,
) -> torch.Tensor:
# ruff: noqa: E501 # isort: skip
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
)
return aiter_triton_fp8_bmm(
X,
WQ,
w_scale,
group_size=group_size,
bias=bias,
dtype=dtype,
splitK=splitK,
YQ=YQ,
transpose_bm=transpose_bm,
config=config,
)
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
@staticmethod
def per_1x128_fp8_quant(
input_2d: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""Only applies quantization method for fp8 data type only."""
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
@staticmethod
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
@staticmethod
def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
return shuffle_weight(tensor, layout=layout)
@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
rocm_aiter_ops.register_ops_once()

View File

@ -1,105 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata(
max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(
max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
)
paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
paged_kv_last_page_lens = torch.full(
(max_batch_size,), block_size, dtype=torch.int32
)
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
if current_platform.is_rocm():
if is_torch_equal_or_newer("2.7.0"):
tags = ()
else:
tags = ((torch.Tag.needs_fixed_stride_order,),)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=mla_decode_fwd_fake,
tags=tags,
)

View File

@ -109,7 +109,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Whether to use aiter rope.
# By default is disabled.
"VLLM_ROCM_USE_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
@ -1589,7 +1589,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM",

View File

@ -14,6 +14,7 @@ import torch.nn.functional as F
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func()
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)

View File

@ -13,6 +13,7 @@ import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
@ -41,8 +42,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
@ -92,13 +91,11 @@ else:
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@ -620,13 +618,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
shuffle_weights,
)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -1002,6 +996,7 @@ def determine_expert_map(
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0,
return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
@ -1064,7 +1059,7 @@ def determine_expert_map(
)
expert_mask = None
if is_rocm_aiter_moe_enabled():
if return_expert_mask:
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
@ -1292,14 +1287,18 @@ class FusedMoE(CustomOp):
self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.aiter_fmoe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
self.num_fused_shared_experts = (
n_shared_experts
if n_shared_experts is not None
and is_rocm_aiter_fusion_shared_expert_enabled()
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
if (
not is_rocm_aiter_fusion_shared_expert_enabled()
not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0
):
raise ValueError(
@ -1346,6 +1345,7 @@ class FusedMoE(CustomOp):
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
@ -1570,13 +1570,16 @@ class FusedMoE(CustomOp):
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
)
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
dp_size=get_dp_group().world_size,
)
def _load_per_tensor_weight_scale(
self,
@ -1753,20 +1756,19 @@ class FusedMoE(CustomOp):
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
if is_rocm_aiter_fusion_shared_expert_enabled():
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
@overload
def weight_loader(
@ -2208,15 +2210,16 @@ class FusedMoE(CustomOp):
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if is_rocm_aiter_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled():
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
grouped_topk_aiter,
rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
@ -2448,7 +2451,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@ -2612,7 +2615,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,

View File

@ -1,17 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache, lru_cache
from functools import lru_cache
import torch
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):
@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
GELU = 1
@cache
def is_rocm_aiter_moe_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_MOE
and envs.VLLM_ROCM_USE_AITER
)
@cache
def use_mxfp4_aiter_moe() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
)
aiter_topK_meta_data = None
@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
def rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl,
fake_impl=rocm_aiter_fused_moe_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake,
)
def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data "
@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
)
else:
assert scoring_func == "softmax" or scoring_func == "sigmoid"
torch.ops.vllm.rocm_aiter_grouped_topk(
rocm_aiter_ops.grouped_topk(
gating_output,
topk_weights,
topk_ids,
@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
routed_scaling_factor=routed_scaling_factor,
)
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
return total_topk_weights, total_topk_ids
return topk_weights, topk_ids
@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
return rocm_aiter_ops.asm_moe_tkw1(
hidden_states,
w1,
w2,
@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype
if quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_fused_moe(
return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
w2,
@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
)
def rocm_aiter_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)

View File

@ -6,18 +6,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
def rms_norm(
@ -58,80 +53,34 @@ def fused_add_rms_norm(
return x, residual
def rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
def poly_norm(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
import aiter as rocm_aiter
from vllm import _custom_ops as ops
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return output, residual_out
return out
def rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
fake_impl=rocm_aiter_rms_norm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
def dispatch_rocm_rmsnorm_func(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation
if with_fused_add:
@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(self.weight)
if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
@staticmethod

View File

@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)

View File

@ -7,12 +7,12 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme

View File

@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@ -369,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
@ -869,12 +869,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
shuffle_weights,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
@ -916,7 +912,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -962,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@ -1042,7 +1038,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start += shard_size
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)

View File

@ -4,54 +4,14 @@
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
fake_impl=rocm_aiter_gemm_w8a8_fake,
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
if not (rocm_aiter_ops.is_linear_enabled()):
return (
False,
"AiterScaledMMLinearKernel is disabled. "
@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
x_q, w_q.t(), x_s, w_s, bias, out_dtype
)
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
use_mxfp4_aiter_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if current_platform.is_rocm():
self.use_marlin = False
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def create_weights(
self,
@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -470,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"not implemented. Please open an issue."
)
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not (
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
)
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
@ -656,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
if not self.emulate:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
aiter_acts = {
ActivationType.No.name.lower(): ActivationType.No,
ActivationType.Silu.name.lower(): ActivationType.Silu,
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
}
assert activation in aiter_acts, (
f"Aiter CK fp4 MoE doesn't support activation {activation}"
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
out = fused_moe(
out = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=aiter_acts[activation],
doweight_stage1=False,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registeration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return (

View File

@ -12,6 +12,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@ -68,78 +69,6 @@ def cutlass_scaled_mm(
)
def rocm_aiter_gemm_w8a8_blockscale_impl(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = input_2d.shape[0]
n = weight.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
input_2d,
weight,
input_scale,
weight_scale,
self.act_quant_group_shape.col,
input_2d.dtype,
)
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
if (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
return rocm_aiter_ops.gemm_w8a8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
def _run_triton(
self,
@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm for MI3XX"""
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""

View File

@ -472,7 +472,7 @@ class Fp8LinearOp:
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)

View File

@ -4,13 +4,10 @@
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .rocm_aiter_rope_ops import (
is_rocm_triton_rotary_embedding_enabled,
rocm_aiter_rotary_emb,
)
@CustomOp.register("rotary_embedding")
@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embedding_enabled = (
is_rocm_triton_rotary_embedding_enabled()
self.is_rocm_triton_rotary_embed_enabled = (
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embedding_enabled:
if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query)
rocm_aiter_rotary_emb(
rocm_aiter_ops.triton_rotary_embed(
positions,
query,
key,

View File

@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot
return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
positions: torch.Tensor,

View File

@ -1,94 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_TRITON_ROPE
)
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter.ops.triton.rope as ops
ops.rope_cached_thd_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass
if is_rocm_triton_rotary_embedding_enabled():
direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_rotary_emb(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
positions,
sin,
cos,
query_,
key_,
rotate_style,
False,
)
query = query.view(query_shape)
key = key.view(key_shape)

View File

@ -33,6 +33,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
@ -50,10 +51,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
if (
config.n_shared_experts is None
or is_rocm_aiter_fusion_shared_expert_enabled()
):
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not is_rocm_aiter_moe_enabled()
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
else None,
)
@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not is_rocm_aiter_moe_enabled():
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,
@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = (
is_rocm_aiter_fusion_shared_expert_enabled()
and ("mlp.shared_experts" in name)
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
"mlp.shared_experts" in name
)
for param_name, weight_name, shard_id in stacked_params_mapping:

View File

@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)
@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
@ -228,19 +233,23 @@ class RocmPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled,
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
if is_aiter_mla_enabled() or block_size == 1
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)
@ -265,12 +274,12 @@ class RocmPlatform(Platform):
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (

View File

@ -198,6 +198,7 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
@ -270,28 +271,15 @@ except ImportError:
flashinfer_available = False
def is_rocm_aiter_fp8bmm_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP8BMM
and envs.VLLM_ROCM_USE_AITER
)
if is_rocm_aiter_fp8bmm_enabled():
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
)
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
logger = init_logger(__name__)
@ -1109,6 +1097,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
@ -1158,7 +1147,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@ -1187,7 +1176,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@ -1196,7 +1185,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@ -1208,10 +1197,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm(
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
# Convert from (B, N, V) to (B, N * V)
@ -1571,7 +1559,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@ -1600,7 +1588,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@ -1609,7 +1597,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@ -1958,7 +1946,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
@ -1966,9 +1953,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,

View File

@ -6,9 +6,8 @@ from typing import ClassVar
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
@ -22,10 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
@ -284,7 +279,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,