mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 02:47:03 +08:00
[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
40e2eeeb92
commit
f080a83511
@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
||||
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
|
||||
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
|
||||
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
|
||||
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
|
||||
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
@ -412,14 +415,12 @@ def test_mixtral_moe(
|
||||
huggingface."""
|
||||
|
||||
# clear the cache before every test
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
# Force reload aiter_ops to pick up the new environment variables.
|
||||
if "rocm_aiter_ops" in sys.modules:
|
||||
importlib.reload(rocm_aiter_ops)
|
||||
|
||||
is_rocm_aiter_moe_enabled.cache_clear()
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
if dtype == torch.float32:
|
||||
pytest.skip("AITER ROCm test skip for float32")
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (
|
||||
@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
dispatch_topk_func,
|
||||
vllm_topk_softmax,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
topk_func = dispatch_topk_func()
|
||||
is_rocm_aiter_moe_enabled.cache_clear()
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_topk_softmax,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
def test_topk_dispatch(use_rocm_aiter: bool):
|
||||
topk_func = dispatch_topk_func(use_rocm_aiter)
|
||||
|
||||
assert topk_func == rocm_aiter_topk_softmax
|
||||
if current_platform.is_rocm() and use_rocm_aiter:
|
||||
assert topk_func == rocm_aiter_ops.topk_softmax
|
||||
else:
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
|
||||
)
|
||||
def test_rms_norm_dispatch(
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
use_rocm_aiter: str,
|
||||
use_rocm_aiter_norm: str,
|
||||
monkeypatch,
|
||||
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
|
||||
):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
|
||||
|
||||
should_use_rocm_aiter = (
|
||||
current_platform.is_rocm()
|
||||
and int(use_rocm_aiter)
|
||||
and int(use_rocm_aiter_norm)
|
||||
and use_rocm_aiter
|
||||
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
)
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
|
||||
elif should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm
|
||||
elif add_residual:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
else:
|
||||
|
||||
941
vllm/_aiter_ops.py
Normal file
941
vllm/_aiter_ops.py
Normal file
@ -0,0 +1,941 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
|
||||
def is_aiter_found() -> bool:
    """Return True when the ROCm AITER package is importable."""
    import importlib.util

    return importlib.util.find_spec("aiter") is not None
||||
|
||||
|
||||
# `find_spec` is not torch.compile compatible.
|
||||
# In cases where aiter availability might have
|
||||
# been checked in forward passes that are torch compiled.
|
||||
# we keep this global outside to not cause torch compile breaks.
|
||||
IS_AITER_FOUND = is_aiter_found()
|
||||
|
||||
|
||||
def if_aiter_supported(func: Callable) -> Callable:
    """Decorator that only executes the function if
    ROCm AITER package is supported on gfx9 archs.

    When unsupported, the wrapped call is a no-op that returns None.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Checks the platform, device arch and aiter library existence.
        from vllm.platforms.rocm import on_gfx9

        if not (current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND):
            # Not supported: do nothing and signal with None.
            return None
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Real implementation: forward to aiter's fused MoE kernel.

    The integer codes `activation_method`/`quant_method` are converted to
    aiter's ActivationType/QuantType enums (ints cross the custom-op
    boundary; enums do not).
    """
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    return fused_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        ActivationType(activation_method),
        QuantType(quant_method),
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
|
||||
|
||||
|
||||
def _rocm_aiter_fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fake (meta) implementation for tracing: output matches the input's
    shape/dtype/device; values are uninitialized."""
    out = torch.empty_like(hidden_states)
    return out
|
||||
|
||||
|
||||
def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Real implementation: forward to aiter's asm_moe_tkw1 kernel."""
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        # Int code -> aiter enum at the op boundary.
        activation=ActivationType(activation_method),
    )
|
||||
|
||||
|
||||
def _rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    """Fake (meta) implementation for tracing: hidden_states-shaped output."""
    out = torch.empty_like(hidden_states)
    return out
|
||||
|
||||
|
||||
def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    """Real implementation: aiter top-k softmax, writing the three output
    tensors (weights, indices, token-expert indices) in place."""
    from aiter import topk_softmax as aiter_topk_softmax

    aiter_topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
|
||||
|
||||
|
||||
def _rocm_aiter_topk_softmax_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    """Fake (meta) implementation: mutating op, nothing to compute."""
    return None
|
||||
|
||||
|
||||
def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Real implementation: aiter biased grouped top-k routing.

    Writes `topk_weights` and `topk_ids` in place.
    """
    from aiter import biased_grouped_topk as aiter_biased_grouped_topk

    aiter_biased_grouped_topk(
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )
|
||||
|
||||
|
||||
def _rocm_aiter_biased_grouped_topk_fake(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Fake (meta) implementation: mutating op, nothing to compute."""
    return None
|
||||
|
||||
|
||||
def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Real implementation: aiter grouped top-k routing.

    Writes `topk_weights` and `topk_ids` in place. aiter takes a boolean
    softmax-vs-sigmoid flag rather than the scoring-function name.
    """
    from aiter import grouped_topk as aiter_grouped_topk

    aiter_grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        scoring_func == "softmax",  # is_softmax
        routed_scaling_factor,
    )
|
||||
|
||||
|
||||
def _rocm_aiter_grouped_topk_fake(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    """Fake (meta) implementation: mutating op, nothing to compute."""
    return None
|
||||
|
||||
|
||||
def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Real implementation: aiter MLA decode attention, writing `o` in place."""
    from aiter.mla import mla_decode_fwd

    # Collapse the KV cache to the (tokens, 1, 1, head_dim) layout
    # the aiter kernel expects.
    kv = kv_buffer.view(-1, 1, 1, q.shape[-1])
    mla_decode_fwd(
        q,
        kv,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
|
||||
|
||||
|
||||
def _rocm_aiter_mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Fake (meta) implementation: mutating op (`o`), nothing to compute."""
    return None
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_w8a8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Real implementation: aiter CK int8/fp8 GEMM with per-tensor scales.

    gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
    a to be [M, K] and b to be [N, K]; note that
    CutlassScaledMMLinearKernel prepares weight `w_q` in [K, N] format.
    """
    from aiter import gemm_a8w8_CK

    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_w8a8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Fake (meta) implementation: [M, N] output on A's device."""
    # A is [M, K], B is [N, K] -> output is [M, N].
    return A.new_empty((A.shape[0], B.shape[0]), dtype=output_dtype)
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_w8a8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Real implementation: aiter block-scaled int8/fp8 GEMM."""
    from aiter import gemm_a8w8_blockscale as aiter_gemm_a8w8_blockscale

    return aiter_gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_w8a8_blockscale_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Fake (meta) implementation: [M, N] output on A's device."""
    # A is [M, K], B is [N, K] -> output is [M, N].
    return A.new_empty((A.shape[0], B.shape[0]), dtype=output_dtype)
|
||||
|
||||
|
||||
def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Real implementation: aiter RMSNorm.

    The aiter kernel operates on 2-D inputs, so higher-rank tensors are
    flattened to (tokens, hidden) and restored afterwards.
    """
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    orig_shape = x.shape
    flat = x.reshape(-1, orig_shape[-1])
    return rms_norm(flat, weight, variance_epsilon).reshape(orig_shape)
|
||||
|
||||
|
||||
def _rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Fake (meta) implementation: output matches the input tensor."""
    out = torch.empty_like(x)
    return out
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation: fused RMSNorm + residual add via aiter.

    Returns (normalized output, updated residual); the aiter kernel fills
    both pre-allocated buffers in place.
    """
    from aiter import rmsnorm2d_fwd_with_add

    out = torch.empty_like(x)
    res_out = torch.empty_like(residual)
    rmsnorm2d_fwd_with_add(
        out,  # output
        x,  # input
        residual,  # residual input
        res_out,  # residual output
        weight,
        variance_epsilon,
    )
    return out, res_out
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation: outputs match the two input tensors."""
    out = torch.empty_like(x)
    res_out = torch.empty_like(residual)
    return out, res_out
|
||||
|
||||
|
||||
# Global flag to ensure ops are registered only once
|
||||
_OPS_REGISTERED = False
|
||||
|
||||
|
||||
class rocm_aiter_ops:
|
||||
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
|
||||
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        """Verifies device specs and availability of aiter main env variable.

        Returns None (falsy) rather than False when `if_aiter_supported`
        short-circuits on non-ROCm/non-gfx9/no-aiter environments.
        """
        return cls._AITER_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enaled(cls) -> bool:
        """Verifies device specs and availability of env variable.

        NOTE(review): method name has a typo ("enaled" -> "enabled");
        kept as-is for caller compatibility — rename with an alias later.
        """
        return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._FMOE_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_fusion_moe_shared_experts_enabled(cls) -> bool:
        """Verifies shared-experts fusion env on top of fused-MoE support."""
        return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MLA_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MHA_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_pa_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_fp8bmm_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
|
||||
|
||||
    @classmethod
    @if_aiter_supported
    def is_triton_rotary_embed_enabled(cls) -> bool:
        """Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
|
||||
|
||||
    @staticmethod
    @if_aiter_supported
    def register_ops_once() -> None:
        """Register every AITER custom op with torch exactly once.

        Guarded by the module-level `_OPS_REGISTERED` flag so repeated calls
        are no-ops; `if_aiter_supported` additionally makes this a no-op on
        unsupported platforms.
        """
        global _OPS_REGISTERED
        if not _OPS_REGISTERED:
            # Pre-2.7 torch needs the fixed-stride-order tag on ops whose
            # inputs must keep their layout (used for mla_decode_fwd below).
            tags = (
                tuple()
                if is_torch_equal_or_newer("2.7.0")
                else (torch.Tag.needs_fixed_stride_order,)
            )

            # register all the custom ops here
            direct_register_custom_op(
                op_name="rocm_aiter_asm_moe_tkw1",
                op_func=_rocm_aiter_asm_moe_tkw1_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_fused_moe",
                op_func=_rocm_aiter_fused_moe_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_fused_moe_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            # In-place op: declares the three output tensors it mutates.
            direct_register_custom_op(
                op_name="rocm_aiter_topk_softmax",
                op_func=_rocm_aiter_topk_softmax_impl,
                mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
                fake_impl=_rocm_aiter_topk_softmax_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_biased_grouped_topk",
                op_func=_rocm_aiter_biased_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_biased_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_grouped_topk",
                op_func=_rocm_aiter_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            # Note: tagged (see `tags` above) and registered without an
            # explicit dispatch key, unlike the other ops here.
            direct_register_custom_op(
                op_name="rocm_aiter_mla_decode_fwd",
                op_func=_rocm_aiter_mla_decode_fwd_impl,
                mutates_args=["o"],
                fake_impl=_rocm_aiter_mla_decode_fwd_fake,
                tags=tags,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_w8a8",
                op_func=_rocm_aiter_gemm_w8a8_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_gemm_w8a8_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_w8a8_blockscale",
                op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rms_norm",
                op_func=_rocm_aiter_rms_norm_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_rms_norm_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
                op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            _OPS_REGISTERED = True
|
||||
|
||||
    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused RMSNorm + residual add; returns (output, new residual).

        Dispatches to the registered custom op — requires
        `register_ops_once()` to have been called.
        """
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
            x, residual, weight, variance_epsilon
        )
|
||||
|
||||
    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
    ) -> torch.Tensor:
        """RMSNorm via the registered `rocm_aiter_rms_norm` custom op."""
        return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
    @staticmethod
    def gemm_w8a8(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Scaled 8-bit GEMM via the registered `rocm_aiter_gemm_w8a8` op.

        A is [M, K], B is [N, K] (see the op impl); result is [M, N] in
        `output_dtype`.
        """
        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
    @staticmethod
    def gemm_w8a8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        """Block-scaled 8-bit GEMM via the registered custom op.

        NOTE(review): `block_size` is accepted but never forwarded — the
        underlying op takes no such argument. Presumably kept for interface
        parity with other blockscale GEMM wrappers; confirm against callers.
        """
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
            A, B, As, Bs, output_dtype
        )
|
||||
|
||||
    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weight: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
        quant_method: int = 0,
        doweight_stage1: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fused MoE via the registered `rocm_aiter_fused_moe` custom op.

        `activation_method`/`quant_method` are integer codes that the op
        impl converts to aiter's ActivationType/QuantType enums.
        """
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation_method,
            quant_method,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )
|
||||
|
||||
    @staticmethod
    def asm_moe_tkw1(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        fc1_scale: torch.Tensor | None = None,
        fc2_scale: torch.Tensor | None = None,
        fc1_smooth_scale: torch.Tensor | None = None,
        fc2_smooth_scale: torch.Tensor | None = None,
        a16: bool = False,
        per_tensor_quant_scale: torch.Tensor | None = None,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
    ) -> torch.Tensor:
        """ASM MoE (tkw1 variant) via the registered custom op."""
        return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale,
            fc2_scale,
            fc1_smooth_scale,
            fc2_smooth_scale,
            a16,
            per_tensor_quant_scale,
            expert_mask,
            activation_method,
        )
|
||||
|
||||
    @staticmethod
    def topk_softmax(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        """Top-k softmax routing via the registered custom op.

        The op fills `topk_weights`, `topk_indices` and
        `token_expert_indices` in place; the first two are returned for
        caller convenience.
        """
        torch.ops.vllm.rocm_aiter_topk_softmax(
            topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
        )
        return topk_weights, topk_indices
|
||||
|
||||
    @staticmethod
    def biased_grouped_topk(
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Biased grouped top-k routing via the registered custom op.

        Results are written into `topk_weights` and `topk_ids` in place.
        """
        torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            routed_scaling_factor,
        )
|
||||
|
||||
    @staticmethod
    def grouped_topk(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
    ) -> None:
        """Grouped top-k routing via the registered custom op.

        Results are written into `topk_weights` and `topk_ids` in place;
        `scoring_func` selects softmax scoring when set to "softmax".
        """
        torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            scoring_func,
            routed_scaling_factor,
        )
|
||||
|
||||
    @staticmethod
    def mla_decode_fwd(
        q: torch.Tensor,
        kv_buffer: torch.Tensor,
        o: torch.Tensor,
        sm_scale: float,
        qo_indptr: torch.Tensor,
        max_seqlen_qo: int,
        kv_indptr: torch.Tensor | None = None,
        kv_indices: torch.Tensor | None = None,
        kv_last_page_lens: torch.Tensor | None = None,
        logit_cap: float = 0.0,
    ):
        """MLA decode attention via the registered custom op; writes `o`.

        NOTE(review): `kv_buffer` is viewed to (-1, 1, 1, head_dim) both
        here and inside the op implementation — the second view is a no-op
        on an already-matching shape, but one of the two looks redundant;
        confirm before removing either.
        """
        torch.ops.vllm.rocm_aiter_mla_decode_fwd(
            q,
            kv_buffer.view(-1, 1, 1, q.shape[-1]),
            o,
            qo_indptr,
            max_seqlen_qo,
            kv_indptr,
            kv_indices,
            kv_last_page_lens,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
        )
|
||||
|
||||
    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """FP4 GEMM with on-the-fly MXFP4 activation quantization (Triton).

        If `x_scales` is None the activation is dynamically quantized;
        otherwise `x` is assumed already quantized with those scales.

        NOTE(review): method name has a typo ("qaunt" -> "quant"); kept
        as-is for caller compatibility.
        """
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
        from aiter.ops.triton.quant import dynamic_mxfp4_quant

        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q = x
            x_s = x_scales

        # Kernel writes into a pre-allocated [M, N] output buffer.
        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y
|
||||
|
||||
    @staticmethod
    def triton_rotary_embed(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        is_neox_style: bool,
    ):
        """Apply rotary position embedding in place via aiter's Triton RoPE.

        `query`/`key` are mutated through views; only the first
        `rotary_dim` elements of each head are rotated.
        """
        from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

        num_tokens = positions.numel()
        # Cache stores cos and sin halves concatenated on the last dim.
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        # 0 = NeoX (rotate halves), 1 = GPT-J (rotate pairs) per the kernel.
        rotate_style = 0 if is_neox_style else 1

        query = query.view(num_tokens, -1, head_size)
        key = key.view(num_tokens, -1, head_size)
        # Slices share storage with query/key, so the in-place kernel
        # updates the caller's tensors.
        query_ = query[..., :rotary_dim]
        key_ = key[..., :rotary_dim]
        positions = positions.view(*query.shape[:1])
        rope_cached_thd_positions_2c_fwd_inplace(
            positions,
            sin,
            cos,
            query_,
            key_,
            rotate_style,
            reuse_freqs_front_part=True,
            is_nope_first=False,
        )
        # These rebind locals only; the caller's tensors were already
        # updated in place by the kernel above.
        query = query.view(query_shape)
        key = key.view(key_shape)
|
||||
@staticmethod
|
||||
def triton_fp8_bmm(
|
||||
X: torch.Tensor,
|
||||
WQ: torch.Tensor,
|
||||
w_scale: torch.Tensor,
|
||||
group_size: int = 128,
|
||||
bias: torch.Tensor | None = None,
|
||||
dtype: torch.dtype | None = torch.bfloat16,
|
||||
splitK: int | None = None,
|
||||
YQ: torch.Tensor | None = None,
|
||||
transpose_bm: bool | None = False,
|
||||
config: dict | None = None,
|
||||
) -> torch.Tensor:
|
||||
# ruff: noqa: E501 # isort: skip
|
||||
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
|
||||
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
|
||||
)
|
||||
|
||||
return aiter_triton_fp8_bmm(
|
||||
X,
|
||||
WQ,
|
||||
w_scale,
|
||||
group_size=group_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
splitK=splitK,
|
||||
YQ=YQ,
|
||||
transpose_bm=transpose_bm,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def triton_gemm_a8w8_blockscale(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
|
||||
|
||||
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
|
||||
|
||||
@staticmethod
|
||||
def per_1x128_fp8_quant(
|
||||
input_2d: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Only applies quantization method for fp8 data type only."""
|
||||
from aiter import QuantType, dtypes, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
||||
return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
|
||||
|
||||
@staticmethod
|
||||
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
|
||||
return (n, k) in [
|
||||
(1024, 8192),
|
||||
(2112, 7168),
|
||||
(3072, 1536),
|
||||
(32768, 8192),
|
||||
(4096, 7168),
|
||||
(4608, 7168),
|
||||
(512, 7168),
|
||||
(7168, 2048),
|
||||
(7168, 256),
|
||||
(8192, 1024),
|
||||
(8192, 32768),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weight(
|
||||
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
) -> torch.Tensor:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
return shuffle_weight(tensor, layout=layout)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weights(
|
||||
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Applies shuffle_weight function from AITER to each
|
||||
input tensor and returns them.
|
||||
|
||||
Rearranges (shuffles) the input tensor/s
|
||||
into a specified block layout for optimized computation.
|
||||
|
||||
Args:
|
||||
*tensors: Variable number of torch.Tensor objects.
|
||||
layout: A pair of integers specifying the block sizes used to divide
|
||||
the tensors during shuffling. Default is (16, 16).
|
||||
|
||||
Returns:
|
||||
A Tuple of shuffled tensors.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
|
||||
|
||||
|
||||
rocm_aiter_ops.register_ops_once()
|
||||
@ -1,105 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
|
||||
def get_aiter_mla_metadata(
|
||||
max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
paged_kv_indices = torch.zeros(
|
||||
max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
|
||||
)
|
||||
paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
|
||||
paged_kv_last_page_lens = torch.full(
|
||||
(max_batch_size,), block_size, dtype=torch.int32
|
||||
)
|
||||
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
sm_scale: float,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: torch.Tensor | None = None,
|
||||
kv_indices: torch.Tensor | None = None,
|
||||
kv_last_page_lens: torch.Tensor | None = None,
|
||||
logit_cap: float = 0.0,
|
||||
):
|
||||
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
max_seqlen_qo,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
|
||||
|
||||
def mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: torch.Tensor | None = None,
|
||||
kv_indices: torch.Tensor | None = None,
|
||||
kv_last_page_lens: torch.Tensor | None = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
|
||||
|
||||
def mla_decode_fwd_fake(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: torch.Tensor | None = None,
|
||||
kv_indices: torch.Tensor | None = None,
|
||||
kv_last_page_lens: torch.Tensor | None = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
if is_torch_equal_or_newer("2.7.0"):
|
||||
tags = ()
|
||||
else:
|
||||
tags = ((torch.Tag.needs_fixed_stride_order,),)
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_mla_decode_fwd",
|
||||
op_func=mla_decode_fwd_impl,
|
||||
mutates_args=["o"],
|
||||
fake_impl=mla_decode_fwd_fake,
|
||||
tags=tags,
|
||||
)
|
||||
@ -109,7 +109,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||
VLLM_ROCM_USE_AITER_MHA: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
||||
VLLM_ROCM_USE_TRITON_ROPE: bool = False
|
||||
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
|
||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
||||
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
|
||||
@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
),
|
||||
# Whether to use aiter rope.
|
||||
# By default is disabled.
|
||||
"VLLM_ROCM_USE_TRITON_ROPE": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use aiter triton fp8 bmm kernel
|
||||
# By default is enabled.
|
||||
@ -1589,7 +1589,7 @@ def compute_hash() -> str:
|
||||
"VLLM_ROCM_USE_AITER_MLA",
|
||||
"VLLM_ROCM_USE_AITER_MHA",
|
||||
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
|
||||
"VLLM_ROCM_USE_TRITON_ROPE",
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE",
|
||||
"VLLM_ROCM_USE_AITER_FP8BMM",
|
||||
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
|
||||
"VLLM_ROCM_USE_AITER_TRITON_GEMM",
|
||||
|
||||
@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
|
||||
|
||||
return rocm_aiter_topk_softmax
|
||||
def dispatch_topk_func(
|
||||
use_rocm_aiter: bool = False,
|
||||
) -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||
if use_rocm_aiter:
|
||||
return rocm_aiter_ops.topk_softmax
|
||||
return vllm_topk_softmax
|
||||
|
||||
|
||||
@ -1121,7 +1120,7 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.distributed import (
|
||||
@ -41,8 +42,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
is_rocm_aiter_fusion_shared_expert_enabled,
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -92,13 +91,11 @@ else:
|
||||
return topk_ids
|
||||
|
||||
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk,
|
||||
)
|
||||
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk as grouped_topk_aiter,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
|
||||
@ -620,13 +618,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
# Padding the weight for better performance on ROCm
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -1002,6 +996,7 @@ def determine_expert_map(
|
||||
global_num_experts: int,
|
||||
expert_placement_strategy: ExpertPlacementStrategy = "linear",
|
||||
num_fused_shared_experts: int = 0,
|
||||
return_expert_mask: bool = False,
|
||||
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Calculates how many experts should be assigned to each rank for EP and
|
||||
@ -1064,7 +1059,7 @@ def determine_expert_map(
|
||||
)
|
||||
|
||||
expert_mask = None
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if return_expert_mask:
|
||||
expert_mask = torch.ones(
|
||||
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
|
||||
)
|
||||
@ -1292,14 +1287,18 @@ class FusedMoE(CustomOp):
|
||||
self.logical_replica_count: torch.Tensor | None = None
|
||||
|
||||
# ROCm aiter shared experts fusion
|
||||
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
self.aiter_fmoe_shared_expert_enabled = (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
)
|
||||
|
||||
self.num_fused_shared_experts = (
|
||||
n_shared_experts
|
||||
if n_shared_experts is not None
|
||||
and is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
|
||||
else 0
|
||||
)
|
||||
if (
|
||||
not is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
not self.aiter_fmoe_shared_expert_enabled
|
||||
and self.num_fused_shared_experts != 0
|
||||
):
|
||||
raise ValueError(
|
||||
@ -1346,6 +1345,7 @@ class FusedMoE(CustomOp):
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_placement_strategy=expert_placement_strategy,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
@ -1570,13 +1570,16 @@ class FusedMoE(CustomOp):
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._init_aiter_shared_experts_topK_buffer(
|
||||
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
|
||||
)
|
||||
if self.aiter_fmoe_shared_expert_enabled:
|
||||
self._init_aiter_shared_experts_topK_buffer(
|
||||
vllm_config=get_current_vllm_config(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
)
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
@ -1753,20 +1756,19 @@ class FusedMoE(CustomOp):
|
||||
def _init_aiter_shared_experts_topK_buffer(
|
||||
self, vllm_config: VllmConfig, dp_size: int
|
||||
):
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled():
|
||||
if self.num_fused_shared_experts > 0:
|
||||
init_aiter_topK_meta_data(
|
||||
n_routed_experts=self.global_num_experts,
|
||||
n_shared_experts=self.num_fused_shared_experts,
|
||||
top_k=self.top_k,
|
||||
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
|
||||
tp_size=self.ep_size if self.use_ep else self.tp_size,
|
||||
shared_experts_score=1.0,
|
||||
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
|
||||
* dp_size,
|
||||
is_EP=self.use_ep,
|
||||
)
|
||||
self.local_num_experts += self.num_fused_shared_experts
|
||||
if self.num_fused_shared_experts > 0:
|
||||
init_aiter_topK_meta_data(
|
||||
n_routed_experts=self.global_num_experts,
|
||||
n_shared_experts=self.num_fused_shared_experts,
|
||||
top_k=self.top_k,
|
||||
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
|
||||
tp_size=self.ep_size if self.use_ep else self.tp_size,
|
||||
shared_experts_score=1.0,
|
||||
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
|
||||
* dp_size,
|
||||
is_EP=self.use_ep,
|
||||
)
|
||||
self.local_num_experts += self.num_fused_shared_experts
|
||||
|
||||
@overload
|
||||
def weight_loader(
|
||||
@ -2208,15 +2210,16 @@ class FusedMoE(CustomOp):
|
||||
elif use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if not is_rocm_aiter_fusion_shared_expert_enabled():
|
||||
if rocm_aiter_ops.is_fused_moe_enabled():
|
||||
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
|
||||
assert num_fused_shared_experts == 0
|
||||
grouped_topk_impl = partial(
|
||||
grouped_topk_aiter,
|
||||
rocm_aiter_grouped_topk,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
)
|
||||
else:
|
||||
grouped_topk_impl = grouped_topk
|
||||
|
||||
topk_weights, topk_ids = grouped_topk_impl(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@ -2448,7 +2451,7 @@ class FusedMoE(CustomOp):
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
@ -2612,7 +2615,7 @@ class FusedMoE(CustomOp):
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import IntEnum
|
||||
from functools import cache, lru_cache
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
class QuantMethod(IntEnum):
|
||||
@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
|
||||
GELU = 1
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_moe_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER_MOE
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def use_mxfp4_aiter_moe() -> bool:
|
||||
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
|
||||
return (
|
||||
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
|
||||
)
|
||||
|
||||
|
||||
aiter_topK_meta_data = None
|
||||
|
||||
|
||||
@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
|
||||
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
|
||||
|
||||
|
||||
def rocm_aiter_asm_moe_tkw1_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||
|
||||
activation = ActivationType(activation_method)
|
||||
|
||||
return asm_moe_tkw1(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
fc1_scale=fc1_scale,
|
||||
fc2_scale=fc2_scale,
|
||||
fc1_smooth_scale=fc1_smooth_scale,
|
||||
fc2_smooth_scale=fc2_smooth_scale,
|
||||
a16=a16,
|
||||
per_tensor_quant_scale=per_tensor_quant_scale,
|
||||
expert_mask=expert_mask,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_asm_moe_tkw1_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax_impl(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> None:
|
||||
from aiter import topk_softmax
|
||||
|
||||
topk_softmax(
|
||||
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_biased_grouped_topk_impl(
|
||||
gating_output: torch.Tensor,
|
||||
correction_bias: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
from aiter import biased_grouped_topk
|
||||
|
||||
biased_grouped_topk(
|
||||
gating_output,
|
||||
correction_bias,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
need_renorm,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_biased_grouped_topk_fake(
|
||||
gating_output: torch.Tensor,
|
||||
correction_bias: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk_impl(
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
from aiter import grouped_topk
|
||||
|
||||
grouped_topk(
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
need_renorm,
|
||||
scoring_func,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk_fake(
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_fused_moe_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
activation = ActivationType(activation_method)
|
||||
quant_type = QuantType(quant_method)
|
||||
|
||||
return fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
expert_mask,
|
||||
activation,
|
||||
quant_type,
|
||||
doweight_stage1,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_fused_moe_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_asm_moe_tkw1",
|
||||
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_fused_moe",
|
||||
op_func=rocm_aiter_fused_moe_impl,
|
||||
fake_impl=rocm_aiter_fused_moe_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_topk_softmax",
|
||||
op_func=rocm_aiter_topk_softmax_impl,
|
||||
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
||||
fake_impl=rocm_aiter_topk_softmax_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_biased_grouped_topk",
|
||||
op_func=rocm_aiter_biased_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_biased_grouped_topk_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_grouped_topk",
|
||||
op_func=rocm_aiter_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_grouped_topk_fake,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
token = hidden_states.shape[0]
|
||||
device = hidden_states.device
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
|
||||
if (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
and num_fused_shared_experts > 0
|
||||
):
|
||||
assert aiter_topK_meta_data is not None, (
|
||||
"AITER topK meta data is not initialized. "
|
||||
"Please ensure that init_aiter_topK_meta_data "
|
||||
@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
|
||||
rocm_aiter_ops.biased_grouped_topk(
|
||||
gating_output,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
topk_weights,
|
||||
@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
|
||||
)
|
||||
else:
|
||||
assert scoring_func == "softmax" or scoring_func == "sigmoid"
|
||||
torch.ops.vllm.rocm_aiter_grouped_topk(
|
||||
rocm_aiter_ops.grouped_topk(
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
|
||||
if (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
and num_fused_shared_experts > 0
|
||||
):
|
||||
return total_topk_weights, total_topk_ids
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
|
||||
"Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
)
|
||||
|
||||
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
|
||||
return rocm_aiter_ops.asm_moe_tkw1(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
|
||||
|
||||
else:
|
||||
quant_method = QuantMethod.NO.value
|
||||
|
||||
# quark moe for mxfp4 w_dtype
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
quant_method = QuantMethod.BLOCK_1X32.value
|
||||
# w8a8 block-scaled
|
||||
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
|
||||
assert not apply_router_weight_on_input, (
|
||||
@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
|
||||
"Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
)
|
||||
|
||||
return torch.ops.vllm.rocm_aiter_fused_moe(
|
||||
return rocm_aiter_ops.fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
|
||||
a2_scale=quant_config.a2_scale,
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
torch.ops.vllm.rocm_aiter_topk_softmax(
|
||||
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def shuffle_weights(
|
||||
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Applies shuffle_weight function from AITER to each
|
||||
input tensor and returns them.
|
||||
|
||||
Rearranges (shuffles) the input tensor/s
|
||||
into a specified block layout for optimized computation.
|
||||
|
||||
Args:
|
||||
*tensors: Variable number of torch.Tensor objects.
|
||||
layout: A pair of integers specifying the block sizes used to divide
|
||||
the tensors during shuffling. Default is (16, 16).
|
||||
|
||||
Returns:
|
||||
A Tuple of shuffled tensors.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
|
||||
|
||||
@ -6,18 +6,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
def rms_norm(
|
||||
@ -58,80 +53,34 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
def poly_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
import aiter as rocm_aiter
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
x = x.reshape(-1, x_original_shape[-1])
|
||||
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
return x.reshape(x_original_shape)
|
||||
|
||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
residual_out = torch.empty_like(residual)
|
||||
output = torch.empty_like(x)
|
||||
rocm_aiter.rmsnorm2d_fwd_with_add(
|
||||
output, # output
|
||||
x, # input
|
||||
residual, # residual input
|
||||
residual_out, # residual output
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return out
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=rocm_aiter_rms_norm_impl,
|
||||
fake_impl=rocm_aiter_rms_norm_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
|
||||
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
|
||||
def dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
|
||||
):
|
||||
use_aiter = use_aiter and dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
|
||||
if use_aiter and with_fused_add:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
return rocm_aiter_ops.rms_norm2d_with_add
|
||||
if use_aiter:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm
|
||||
return rocm_aiter_ops.rms_norm
|
||||
|
||||
# fall back to CUDA implementation
|
||||
if with_fused_add:
|
||||
@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=False, dtype=weight_dtype
|
||||
with_fused_add=False,
|
||||
dtype=weight_dtype,
|
||||
use_aiter=aiter_rmsnorm_enabled,
|
||||
)
|
||||
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=True, dtype=weight_dtype
|
||||
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# cutlass path
|
||||
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
|
||||
@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
|
||||
@ -7,12 +7,12 @@ import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
)
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
assert not self.is_static_input_scheme
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
@ -369,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
@ -869,12 +869,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
@ -916,7 +912,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -962,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
@ -1042,7 +1038,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
start += shard_size
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
|
||||
@ -4,54 +4,14 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
from aiter import gemm_a8w8_CK
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
return Y
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8",
|
||||
op_func=rocm_aiter_gemm_w8a8_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
||||
)
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
+ "installed on ROCm.",
|
||||
)
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
|
||||
if not (rocm_aiter_ops.is_linear_enabled()):
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel is disabled. "
|
||||
@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
|
||||
x_q, w_q.t(), x_s, w_s, bias, out_dtype
|
||||
)
|
||||
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
use_mxfp4_aiter_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
)
|
||||
@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -470,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"not implemented. Please open an issue."
|
||||
)
|
||||
|
||||
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
self.emulate = not current_platform.supports_mx() or not (
|
||||
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
)
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
|
||||
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
|
||||
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
@ -656,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
if not self.emulate:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
aiter_acts = {
|
||||
ActivationType.No.name.lower(): ActivationType.No,
|
||||
ActivationType.Silu.name.lower(): ActivationType.Silu,
|
||||
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
|
||||
}
|
||||
assert activation in aiter_acts, (
|
||||
f"Aiter CK fp4 MoE doesn't support activation {activation}"
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
out = fused_moe(
|
||||
|
||||
out = rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=aiter_acts[activation],
|
||||
doweight_stage1=False,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: move registration of custom op to aiter_ops.py
|
||||
# `from vllm._aiter_ops import rocm_aiter_ops`
|
||||
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
|
||||
# for envs checks which does not require @cache anymore.
|
||||
# triton kernel is torch compile compatible.
|
||||
# does not require direct registeration.
|
||||
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
|
||||
@cache
|
||||
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
return (
|
||||
|
||||
@ -12,6 +12,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@ -68,78 +69,6 @@ def cutlass_scaled_mm(
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
def is_aiter_triton_kernel_tuned(n, k):
|
||||
return (n, k) in [
|
||||
(1024, 8192),
|
||||
(2112, 7168),
|
||||
(3072, 1536),
|
||||
(32768, 8192),
|
||||
(4096, 7168),
|
||||
(4608, 7168),
|
||||
(512, 7168),
|
||||
(7168, 2048),
|
||||
(7168, 256),
|
||||
(8192, 1024),
|
||||
(8192, 32768),
|
||||
]
|
||||
|
||||
n, k = weight.shape
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
|
||||
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
|
||||
|
||||
# MI350 case uses triton kernel
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
group_size,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
else:
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
import aiter as rocm_aiter
|
||||
from aiter import gemm_a8w8_blockscale, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
|
||||
q_input, input_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
|
||||
)
|
||||
|
||||
return gemm_a8w8_blockscale(
|
||||
q_input, weight, input_scale, weight_scale, dtype=output_dtype
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_fake(
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
m = input_2d.shape[0]
|
||||
n = weight.shape[0]
|
||||
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8_blockscale",
|
||||
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||
)
|
||||
|
||||
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
|
||||
input_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
|
||||
input_2d,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
self.act_quant_group_shape.col,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
n, k = weight.shape
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
|
||||
# MI350 case uses triton kernel
|
||||
if (
|
||||
not current_platform.is_fp8_fnuz()
|
||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||
):
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
self.act_quant_group_shape.col,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
else:
|
||||
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
|
||||
return rocm_aiter_ops.gemm_w8a8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
|
||||
s_old.copy_(s_requant)
|
||||
|
||||
|
||||
def check_aiter_fp8_linear_support() -> bool:
|
||||
"""AITER is only supported on ROCm for MI3XX"""
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
)
|
||||
|
||||
|
||||
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
can benefit from tensors located far enough from one another in memory"""
|
||||
|
||||
@ -472,7 +472,7 @@ class Fp8LinearOp:
|
||||
# Example:
|
||||
# When the number of token is 1, per-token scale is [[1]]
|
||||
# When per-tensor scale is [1] or ().
|
||||
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
|
||||
@ -4,13 +4,10 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from .common import apply_rotary_emb_torch
|
||||
from .rocm_aiter_rope_ops import (
|
||||
is_rocm_triton_rotary_embedding_enabled,
|
||||
rocm_aiter_rotary_emb,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.is_rocm_triton_rotary_embedding_enabled = (
|
||||
is_rocm_triton_rotary_embedding_enabled()
|
||||
self.is_rocm_triton_rotary_embed_enabled = (
|
||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.is_rocm_triton_rotary_embedding_enabled:
|
||||
if self.is_rocm_triton_rotary_embed_enabled:
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
rocm_aiter_rotary_emb(
|
||||
rocm_aiter_ops.triton_rotary_embed(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
|
||||
@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
key = key_rot
|
||||
return query, key
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
@ -1,94 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_triton_rotary_embedding_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_TRITON_ROPE
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
import aiter.ops.triton.rope as ops
|
||||
|
||||
ops.rope_cached_thd_positions_2c_fwd_inplace(
|
||||
query,
|
||||
key,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=is_nope_first,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
|
||||
positions: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
rotate_style: int = 0,
|
||||
is_nope_first: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if is_rocm_triton_rotary_embedding_enabled():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
|
||||
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
|
||||
mutates_args=["key", "query"],
|
||||
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_rotary_emb(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool,
|
||||
):
|
||||
num_tokens = positions.numel()
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
query_shape = query.shape
|
||||
key_shape = key.shape
|
||||
rotate_style = 0 if is_neox_style else 1
|
||||
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
query_ = query[..., :rotary_dim]
|
||||
key_ = key[..., :rotary_dim]
|
||||
positions = positions.view(*query.shape[:1])
|
||||
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
|
||||
positions,
|
||||
sin,
|
||||
cos,
|
||||
query_,
|
||||
key_,
|
||||
rotate_style,
|
||||
False,
|
||||
)
|
||||
query = query.view(query_shape)
|
||||
key = key.view(key_shape)
|
||||
@ -33,6 +33,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
@ -50,10 +51,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_fusion_shared_expert_enabled,
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
if (
|
||||
config.n_shared_experts is None
|
||||
or is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
):
|
||||
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
|
||||
self.shared_experts = None
|
||||
else:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
# aiter applies routed_scaling_factor internally
|
||||
routed_scaling_factor=1.0
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
if not self.is_rocm_aiter_moe_enabled
|
||||
else self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
n_shared_experts=config.n_shared_experts
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
else None,
|
||||
)
|
||||
|
||||
@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
if hidden_states.dtype != torch.float16:
|
||||
if not is_rocm_aiter_moe_enabled():
|
||||
if not self.is_rocm_aiter_moe_enabled:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rocm_aiter_moe_shared_expert_enabled = (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
)
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
|
||||
num_experts=self.config.n_routed_experts
|
||||
+ (
|
||||
self.config.n_shared_experts
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
if rocm_aiter_moe_shared_expert_enabled
|
||||
else 0
|
||||
),
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
is_fuse_shared_experts_layer = (
|
||||
is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
and ("mlp.shared_experts" in name)
|
||||
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
|
||||
"mlp.shared_experts" in name
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
|
||||
@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> bool:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
|
||||
and not (rocm_aiter_ops.is_pa_attn_enabled())
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
@ -202,12 +204,15 @@ class RocmPlatform(Platform):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
# Note: AITER FA is only supported for Qwen-VL models.
|
||||
# TODO: Add support for other VL models in their model class.
|
||||
return _Backend.ROCM_AITER_FA
|
||||
|
||||
if on_gfx9() and find_spec("flash_attn") is not None:
|
||||
@ -228,19 +233,23 @@ class RocmPlatform(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
|
||||
if use_mla:
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
|
||||
is_aiter_mla_enabled,
|
||||
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||
"to select a supported backend."
|
||||
)
|
||||
|
||||
if use_mla:
|
||||
if selected_backend is None:
|
||||
selected_backend = (
|
||||
_Backend.ROCM_AITER_MLA
|
||||
if is_aiter_mla_enabled() or block_size == 1
|
||||
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA
|
||||
)
|
||||
|
||||
@ -265,12 +274,12 @@ class RocmPlatform(Platform):
|
||||
logger.info("Using FlexAttention backend.")
|
||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
|
||||
rocm_aiter_ops.is_mha_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_FA:
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
rocm_aiter_ops.is_triton_unified_attn_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
logger.info("Using Aiter Unified Attention backend.")
|
||||
return (
|
||||
|
||||
@ -198,6 +198,7 @@ from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
@ -270,28 +271,15 @@ except ImportError:
|
||||
flashinfer_available = False
|
||||
|
||||
|
||||
def is_rocm_aiter_fp8bmm_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
)
|
||||
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
|
||||
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
|
||||
)
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
|
||||
):
|
||||
DTYPE_MAX = torch.finfo(dtype).max
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
|
||||
scale = DTYPE_MAX / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
|
||||
):
|
||||
DTYPE_MAX = torch.finfo(dtype).max
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
|
||||
scale = DTYPE_MAX / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -1109,6 +1097,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.indexer = indexer
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
def get_layer_weight(layer):
|
||||
@ -1158,7 +1147,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
@ -1187,7 +1176,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device,
|
||||
)
|
||||
aiter_triton_fp8_bmm(
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
|
||||
@ -1196,7 +1185,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device,
|
||||
)
|
||||
aiter_triton_fp8_bmm(
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
else:
|
||||
@ -1208,10 +1197,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = aiter_triton_fp8_bmm(
|
||||
x = rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
@ -1571,7 +1559,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
@ -1600,7 +1588,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device,
|
||||
)
|
||||
aiter_triton_fp8_bmm(
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
|
||||
@ -1609,7 +1597,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device,
|
||||
)
|
||||
aiter_triton_fp8_bmm(
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
else:
|
||||
@ -1958,7 +1946,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||
|
||||
# Pads the head_dim if necessary (for the underlying kernel)
|
||||
if self.q_pad_num_heads is not None:
|
||||
B, N, L = decode_q_pe.shape
|
||||
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
|
||||
@ -1966,9 +1953,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
decode_pe_padded.copy_(decode_q_pe)
|
||||
decode_q_pe = decode_pe_padded
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
decode_ql_nope = aiter_triton_fp8_bmm(
|
||||
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
||||
decode_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
|
||||
@ -6,9 +6,8 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
@ -22,10 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -284,7 +279,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
aiter_mla_decode_fwd(
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer,
|
||||
o,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user