# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable import torch import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer _FP8_DTYPE = current_platform.fp8_dtype() def is_aiter_found() -> bool: from importlib.util import find_spec return find_spec("aiter") is not None # `find_spec` is not torch.compile compatible. # In cases where aiter availability might have # been checked in forward passes that are torch compiled. # we keep this global outside to not cause torch compile breaks. IS_AITER_FOUND = is_aiter_found() def is_aiter_found_and_supported() -> bool: if current_platform.is_rocm() and IS_AITER_FOUND: from vllm.platforms.rocm import on_gfx9 return on_gfx9() return False def if_aiter_supported(func: Callable) -> Callable: """Decorator that only executes the function if ROCm AITER package is supported on gfx9 archs. """ @functools.wraps(func) def wrapper(*args, **kwargs): # checks the platform, device arch and aiter library existence. if is_aiter_found_and_supported(): return func(*args, **kwargs) return None return wrapper # Can't use dtypes.fp8 directly inside an op # because it returns wrong result on gfx942. # This is a workaround to get the correct FP8 dtype. # This might because that the get_gfx() is wrapped as a custom op. if is_aiter_found_and_supported(): from aiter import dtypes AITER_FP8_DTYPE = dtypes.fp8 def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, expert_mask: torch.Tensor | None = None, activation_method: int = 0, quant_method: int = 0, doweight_stage1: bool = False, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe activation = ActivationType(activation_method) quant_type = QuantType(quant_method) return fused_moe( hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, activation, quant_type, doweight_stage1, w1_scale, w2_scale, a1_scale, a2_scale, ) def _rocm_aiter_fused_moe_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, expert_mask: torch.Tensor | None = None, activation_method: int = 0, quant_method: int = 0, doweight_stage1: bool = False, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) def _rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: torch.Tensor | None = None, fc2_scale: torch.Tensor | None = None, fc1_smooth_scale: torch.Tensor | None = None, fc2_smooth_scale: torch.Tensor | None = None, a16: bool = False, per_tensor_quant_scale: torch.Tensor | None = None, expert_mask: torch.Tensor | None = None, activation_method: int = 0, ) -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 activation = ActivationType(activation_method) return asm_moe_tkw1( hidden_states, w1, w2, topk_weights, topk_ids, fc1_scale=fc1_scale, fc2_scale=fc2_scale, fc1_smooth_scale=fc1_smooth_scale, fc2_smooth_scale=fc2_smooth_scale, a16=a16, per_tensor_quant_scale=per_tensor_quant_scale, expert_mask=expert_mask, activation=activation, ) def _rocm_aiter_asm_moe_tkw1_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: torch.Tensor | None = None, fc2_scale: torch.Tensor | None = None, fc1_smooth_scale: torch.Tensor | None = None, fc2_smooth_scale: torch.Tensor | None = None, a16: bool = False, per_tensor_quant_scale: torch.Tensor | None = None, expert_mask: torch.Tensor | None = None, activation_method: int = 0, ) -> torch.Tensor: return torch.empty_like(hidden_states) def _rocm_aiter_topk_softmax_impl( topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, ) -> None: from aiter import topk_softmax topk_softmax( topk_weights, topk_indices, token_expert_indices, gating_output, renormalize ) def _rocm_aiter_topk_softmax_fake( topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, ) -> None: pass def _rocm_aiter_biased_grouped_topk_impl( gating_output: torch.Tensor, correction_bias: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: from aiter import biased_grouped_topk biased_grouped_topk( gating_output, correction_bias, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, routed_scaling_factor, ) def _rocm_aiter_biased_grouped_topk_fake( gating_output: torch.Tensor, correction_bias: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass def _rocm_aiter_grouped_topk_impl( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: is_softmax = scoring_func == "softmax" from aiter import grouped_topk grouped_topk( gating_output, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, is_softmax, routed_scaling_factor, ) def _rocm_aiter_grouped_topk_fake( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass # Cache whether aiter supports FP8 MLA parameters _AITER_MLA_SUPPORTS_FP8: bool | None = None def _check_aiter_mla_fp8_support() -> bool: """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters.""" global _AITER_MLA_SUPPORTS_FP8 if _AITER_MLA_SUPPORTS_FP8 is None: try: import inspect from aiter.mla import mla_decode_fwd sig = inspect.signature(mla_decode_fwd) _AITER_MLA_SUPPORTS_FP8 = ( "q_scale" in sig.parameters and "kv_scale" in sig.parameters ) except Exception: _AITER_MLA_SUPPORTS_FP8 = False return _AITER_MLA_SUPPORTS_FP8 def _rocm_aiter_mla_decode_fwd_impl( q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor, qo_indptr: torch.Tensor, max_seqlen_qo: int, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, ) -> None: from aiter.mla import mla_decode_fwd kwargs = { "sm_scale": sm_scale, "logit_cap": logit_cap, } # Only pass q_scale and kv_scale if the aiter library supports them if _check_aiter_mla_fp8_support(): kwargs["q_scale"] = q_scale kwargs["kv_scale"] = kv_scale mla_decode_fwd( q, kv_buffer.view(-1, 1, 1, q.shape[-1]), o, qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, max_seqlen_qo, **kwargs, ) def _rocm_aiter_mla_decode_fwd_fake( q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor, qo_indptr: torch.Tensor, max_seqlen_qo: int, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, ) -> None: pass def _rocm_aiter_gemm_a8w8_impl( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, bias: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: from aiter import gemm_a8w8_CK # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) def _rocm_aiter_gemm_a8w8_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, bias: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) return Y def _rocm_aiter_gemm_a8w8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: from aiter import gemm_a8w8_blockscale return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) def _rocm_aiter_gemm_a8w8_blockscale_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) return Y def _rocm_aiter_rms_norm_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: from aiter import rms_norm if x.dim() > 2: x_original_shape = x.shape x = x.reshape(-1, x_original_shape[-1]) x = rms_norm(x, weight, variance_epsilon) return x.reshape(x_original_shape) return rms_norm(x, weight, variance_epsilon) def _rocm_aiter_rms_norm_fake( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: return torch.empty_like(x) def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: from aiter import rmsnorm2d_fwd_with_add residual_out = torch.empty_like(residual) output = torch.empty_like(x) rmsnorm2d_fwd_with_add( output, # output x, # input residual, # residual input residual_out, # residual output weight, variance_epsilon, ) return output, residual_out def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(x), torch.empty_like(residual) def _rocm_aiter_per_tensor_quant_impl( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: from aiter.ops.quant import per_tensor_quant_hip return per_tensor_quant_hip(x, scale, quant_dtype) def _rocm_aiter_per_tensor_quant_fake( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(x, dtype=quant_dtype), torch.empty( 1, dtype=torch.float32, device=x.device ) def _rocm_aiter_per_token_quant_impl( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: from aiter.ops.quant import dynamic_per_token_scaled_quant assert quant_dtype in [torch.int8, _FP8_DTYPE] out_shape = x.shape out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device) if scale is None: scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) dynamic_per_token_scaled_quant( out, x, scale, scale_ub=None, shuffle_scale=False, num_rows=None, num_rows_factor=1, ) return out, scale def _rocm_aiter_per_token_quant_fake( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: out_shape = x.shape return ( torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device), torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), ) def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( x, weight, variance_epsilon, None, None, None, group_size=group_size, dtype_quant=AITER_FP8_DTYPE, res1=residual, ) return (x_quant, x_quant_scales, res) def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: M, N = x.shape scale_shape = (M, (N + group_size - 1) // group_size) return ( torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device), torch.empty_like(residual, device=residual.device), ) def _rocm_aiter_rmsnorm_fp8_group_quant_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( x, weight, variance_epsilon, None, None, None, group_size=group_size, dtype_quant=AITER_FP8_DTYPE, res1=None, ) return (x_quant, x_quant_scales) def _rocm_aiter_rmsnorm_fp8_group_quant_fake( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: M, N = x.shape scale_shape = (M, (N + group_size - 1) // group_size) return ( torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device), ) def _rocm_aiter_group_fp8_quant_impl( x: torch.Tensor, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" from aiter import QuantType, get_hip_quant aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) def _rocm_aiter_group_fp8_quant_fake( x: torch.Tensor, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: M, N = x.shape x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) out_bs = torch.empty( ( M, (N + group_size - 1) // group_size, ), dtype=torch.float32, device=x.device, ) return x_fp8, out_bs def _rocm_aiter_act_mul_and_fp8_group_quant_impl( x: torch.Tensor, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: from aiter.ops.triton.activation import act_mul_and_fp8_group_quant return act_mul_and_fp8_group_quant( x, activation="silu", group_size=group_size, dtype_quant=AITER_FP8_DTYPE, ) def _rocm_aiter_act_mul_and_fp8_group_quant_fake( x: torch.Tensor, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: M, N = x.shape assert N % 2 == 0 N_half = N // 2 x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) out_bs = torch.empty( ( M, (N_half + group_size - 1) // group_size, ), dtype=torch.float32, device=x.device, ) return x_fp8, out_bs # Global flag to ensure ops are registered only once _OPS_REGISTERED = False class rocm_aiter_ops: """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM. This class centralizes the import and registration of AITER ops, and provides a unified interface for checking if AITER is enabled. Operations are only available on supported gfx9 architectures when aiter is installed. The class uses environment variables to control which features are enabled, allowing fine-grained control over which AITER optimizations are used. Environment Variables: VLLM_ROCM_USE_AITER: Main toggle for all AITER operations. VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops. VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations. VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops. VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops. VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen. VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention. VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply. VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM. VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings. VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion. VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM. Note: The environment variables are assigned when the module is imported, so you can't change the environment variables after the module is imported. This is done out of performance consideration. Accessing environment variables is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067 so we don't want to do it repeatedly, especially in the hot path (the forward pass). You can call the refresh_env_variables() function to reload the env variables after monkey patching the env variables in the unit test. Check Functions: All check functions (is_*_enabled) are decorated with @if_aiter_supported, which verifies: (1) platform is ROCm, (2) device arch is gfx9, and (3) aiter library is installed. The check function then also verifies the corresponding environment variable is enabled. i.e. ___ is_enabled() == current_platform.is_rocm() and | checked by current_platform.is_on_gfx9() and | @if_aiter_supported IS_AITER_FOUND and _______________| cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` Example: from vllm._aiter_ops import rocm_aiter_ops # Check if aiter is enabled before using operations if rocm_aiter_ops.is_enabled(): result = rocm_aiter_ops.rms_norm(x, weight, epsilon) Operations: - RMS normalization: rms_norm, rms_norm2d_with_add - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale - Fused MoE: fused_moe, asm_moe_tkw1 - Routing: topk_softmax, biased_grouped_topk, grouped_topk - MLA decode: mla_decode_fwd - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale """ # Check if the env variable is set _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM @classmethod def refresh_env_variables(cls): """ Since the environment variables are assigned when the module is imported, This is a helper function to reload all the env variables from the environment variables. for example, after monkey patching the env variables in the unit test, you can call this function to reload the env variables. """ cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM @classmethod @if_aiter_supported def is_enabled(cls) -> bool: return cls._AITER_ENABLED @classmethod @if_aiter_supported def is_linear_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod @if_aiter_supported def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() @classmethod @if_aiter_supported def is_rmsnorm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod @if_aiter_supported def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED @classmethod @if_aiter_supported def is_mla_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod @if_aiter_supported def is_mha_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MHA_ENABLED @classmethod @if_aiter_supported def is_triton_unified_attn_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod @if_aiter_supported def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM @classmethod @if_aiter_supported def is_triton_rotary_embed_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod @if_aiter_supported def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM @staticmethod @if_aiter_supported def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: tags = ( tuple() if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order,) ) # register all the custom ops here direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=_rocm_aiter_asm_moe_tkw1_impl, mutates_args=[], fake_impl=_rocm_aiter_asm_moe_tkw1_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_fused_moe", op_func=_rocm_aiter_fused_moe_impl, mutates_args=[], fake_impl=_rocm_aiter_fused_moe_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_topk_softmax", op_func=_rocm_aiter_topk_softmax_impl, mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], fake_impl=_rocm_aiter_topk_softmax_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_biased_grouped_topk", op_func=_rocm_aiter_biased_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=_rocm_aiter_biased_grouped_topk_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_grouped_topk", op_func=_rocm_aiter_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=_rocm_aiter_grouped_topk_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_mla_decode_fwd", op_func=_rocm_aiter_mla_decode_fwd_impl, mutates_args=["o"], fake_impl=_rocm_aiter_mla_decode_fwd_fake, tags=tags, ) direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8", op_func=_rocm_aiter_gemm_a8w8_impl, mutates_args=[], fake_impl=_rocm_aiter_gemm_a8w8_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8_blockscale", op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, ) direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=_rocm_aiter_rms_norm_impl, fake_impl=_rocm_aiter_rms_norm_fake, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake, ) direct_register_custom_op( op_name="rocm_aiter_act_mul_and_fp8_group_quant", op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl, fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, ) direct_register_custom_op( op_name="rocm_aiter_group_fp8_quant", op_func=_rocm_aiter_group_fp8_quant_impl, fake_impl=_rocm_aiter_group_fp8_quant_fake, ) direct_register_custom_op( op_name="rocm_aiter_per_tensor_quant", op_func=_rocm_aiter_per_tensor_quant_impl, mutates_args=[], fake_impl=_rocm_aiter_per_tensor_quant_fake, dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_per_token_quant", op_func=_rocm_aiter_per_token_quant_impl, mutates_args=["scale"], fake_impl=_rocm_aiter_per_token_quant_fake, dispatch_key=current_platform.dispatch_key, ) _OPS_REGISTERED = True @staticmethod def rms_norm2d_with_add( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( x, residual, weight, variance_epsilon ) @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) @staticmethod def gemm_a8w8( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, bias: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype) @staticmethod def gemm_a8w8_blockscale( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale( A, B, As, Bs, output_dtype ) @staticmethod def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, expert_mask: torch.Tensor | None = None, activation_method: int = 0, quant_method: int = 0, doweight_stage1: bool = False, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_fused_moe( hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, activation_method, quant_method, doweight_stage1, w1_scale, w2_scale, a1_scale, a2_scale, ) @staticmethod def asm_moe_tkw1( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: torch.Tensor | None = None, fc2_scale: torch.Tensor | None = None, fc1_smooth_scale: torch.Tensor | None = None, fc2_smooth_scale: torch.Tensor | None = None, a16: bool = False, per_tensor_quant_scale: torch.Tensor | None = None, expert_mask: torch.Tensor | None = None, activation_method: int = 0, ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( hidden_states, w1, w2, topk_weights, topk_ids, fc1_scale, fc2_scale, fc1_smooth_scale, fc2_smooth_scale, a16, per_tensor_quant_scale, expert_mask, activation_method, ) @staticmethod def topk_softmax( topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, ) -> tuple[torch.Tensor, ...]: torch.ops.vllm.rocm_aiter_topk_softmax( topk_weights, topk_indices, token_expert_indices, gating_output, renormalize ) return topk_weights, topk_indices @staticmethod def biased_grouped_topk( gating_output: torch.Tensor, correction_bias: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, routed_scaling_factor: float = 1.0, ) -> None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( gating_output, correction_bias, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, routed_scaling_factor, ) @staticmethod def grouped_topk( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_expert_group: int, topk_group: int, need_renorm: bool, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, ) -> None: torch.ops.vllm.rocm_aiter_grouped_topk( gating_output, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, scoring_func, routed_scaling_factor, ) @staticmethod def mla_decode_fwd( q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor, sm_scale: float, qo_indptr: torch.Tensor, max_seqlen_qo: int, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_last_page_lens: torch.Tensor | None = None, logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, ): torch.ops.vllm.rocm_aiter_mla_decode_fwd( q, kv_buffer.view(-1, 1, 1, q.shape[-1]), o, qo_indptr, max_seqlen_qo, kv_indptr, kv_indices, kv_last_page_lens, sm_scale=sm_scale, logit_cap=logit_cap, q_scale=q_scale, kv_scale=kv_scale, ) @staticmethod def per_tensor_quant( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale) @staticmethod def per_token_quant( x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale) @staticmethod def triton_fp4_gemm_dynamic_qaunt( x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, out_dtype: torch.dtype | None = torch.bfloat16, x_scales: torch.Tensor | None = None, ) -> torch.Tensor: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant if x_scales is None: x_q, x_s = dynamic_mxfp4_quant(x) else: x_q = x x_s = x_scales y = torch.empty( x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype ) gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) return y @staticmethod def triton_rotary_embed( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, cos_sin_cache: torch.Tensor, head_size: int, rotary_dim: int, is_neox_style: bool, ): from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace num_tokens = positions.numel() cos, sin = cos_sin_cache.chunk(2, dim=-1) query_shape = query.shape key_shape = key.shape rotate_style = 0 if is_neox_style else 1 query = query.view(num_tokens, -1, head_size) key = key.view(num_tokens, -1, head_size) query_ = query[..., :rotary_dim] key_ = key[..., :rotary_dim] positions = positions.view(*query.shape[:1]) rope_cached_thd_positions_2c_fwd_inplace( positions, sin, cos, query_, key_, rotate_style, reuse_freqs_front_part=True, is_nope_first=False, ) query = query.view(query_shape) key = key.view(key_shape) @staticmethod def triton_fp8_bmm( X: torch.Tensor, WQ: torch.Tensor, w_scale: torch.Tensor, group_size: int = 128, bias: torch.Tensor | None = None, dtype: torch.dtype | None = torch.bfloat16, splitK: int | None = None, YQ: torch.Tensor | None = None, transpose_bm: bool | None = False, config: dict | None = None, ) -> torch.Tensor: # ruff: noqa: E501 # isort: skip from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, ) return aiter_triton_fp8_bmm( X, WQ, w_scale, group_size=group_size, bias=bias, dtype=dtype, splitK=splitK, YQ=YQ, transpose_bm=transpose_bm, config=config, ) @staticmethod def triton_gemm_a8w8_blockscale( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) @staticmethod def group_fp8_quant( input_2d: torch.Tensor, group_size: int = 128, ) -> tuple[torch.Tensor, ...]: assert group_size == 128, "Group size must be 128" return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size) @staticmethod def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: return (n, k) in [ (1024, 8192), (2112, 7168), (3072, 1536), (32768, 8192), (4096, 7168), (4608, 7168), (512, 7168), (7168, 2048), (7168, 256), (8192, 1024), (8192, 32768), ] @staticmethod def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool: return (n, k) in [ (8192, 4096), (1280, 8192), (16384, 53248), (106496, 16384), (57344, 8192), (8192, 2048), (2560, 8192), (10240, 8192), (16384, 16384), (8192, 28672), (28672, 8192), (18432, 16384), (8192, 1024), (7168, 8192), (5120, 8192), (8192, 8192), (8192, 7168), (14336, 8192), (8192, 14336), (8192, 3584), ] @staticmethod def shuffle_weight( self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> torch.Tensor: from aiter.ops.shuffle import shuffle_weight return shuffle_weight(tensor, layout=layout) @staticmethod def shuffle_weights( *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. Rearranges (shuffles) the input tensor/s into a specified block layout for optimized computation. Args: *tensors: Variable number of torch.Tensor objects. layout: A pair of integers specifying the block sizes used to divide the tensors during shuffling. Default is (16, 16). Returns: A Tuple of shuffled tensors. """ from aiter.ops.shuffle import shuffle_weight return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) rocm_aiter_ops.register_ops_once()