diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 010817e79a936..c32bf04c71c1f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -642,48 +642,130 @@ _OPS_REGISTERED = False class rocm_aiter_ops: + """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM. + + This class centralizes the import and registration of AITER ops, + and provides a unified interface for checking if AITER is enabled. + Operations are only available on supported gfx9 + architectures when aiter is installed. + + The class uses environment variables to control which features are enabled, + allowing fine-grained control over which AITER optimizations are used. + + Environment Variables: + VLLM_ROCM_USE_AITER: Main toggle for all AITER operations. + VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops. + VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations. + VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops. + VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops. + VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen. + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention. + VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply. + VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM. + VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings. + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion. + VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM. + + Note: + The environment variables are assigned when the module is imported, + so you can't change the environment variables after the module is imported. + This is done out of performance consideration. Accessing environment variables + is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067 + so we don't want to do it repeatedly, especially in the hot path (the forward pass). + You can call the refresh_env_variables() function to reload the env variables + after monkey patching the env variables in the unit test. + + Check Functions: + All check functions (is_*_enabled) are decorated with @if_aiter_supported, + which verifies: (1) platform is ROCm, (2) device arch is gfx9, and + (3) aiter library is installed. The check function then also verifies + the corresponding environment variable is enabled. + i.e. ___ + is_enabled() == current_platform.is_rocm() and | checked by + current_platform.is_on_gfx9() and | @if_aiter_supported + IS_AITER_FOUND and _______________| + cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` + + Example: + from vllm._aiter_ops import rocm_aiter_ops + + # Check if aiter is enabled before using operations + if rocm_aiter_ops.is_enabled(): + result = rocm_aiter_ops.rms_norm(x, weight, epsilon) + + Operations: + - RMS normalization: rms_norm, rms_norm2d_with_add + - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale + - Fused MoE: fused_moe, asm_moe_tkw1 + - Routing: topk_softmax, biased_grouped_topk, grouped_topk + - MLA decode: mla_decode_fwd + - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant + - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale + """ + + # Check if the env variable is set _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA - _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @classmethod + def refresh_env_variables(cls): + """ + Since the environment variables are assigned when the module is imported, + This is a helper function to reload all the env variables from + the environment variables. + for example, after monkey patching the env variables in the unit test, + you can call this function to reload the env variables. + """ + cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @classmethod @if_aiter_supported def is_enabled(cls) -> bool: - """Verifies device specs and availability of aiter main env variable.""" return cls._AITER_ENABLED @classmethod @if_aiter_supported def is_linear_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod @if_aiter_supported def is_linear_fp8_enaled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls.is_linear_enabled() @classmethod @if_aiter_supported def is_rmsnorm_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod @@ -694,25 +776,16 @@ class rocm_aiter_ops: @classmethod @if_aiter_supported def is_mla_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod @if_aiter_supported def is_mha_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MHA_ENABLED - @classmethod - @if_aiter_supported - def is_pa_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" - return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED - @classmethod @if_aiter_supported def is_triton_unified_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e469a928da229..c237f7cf887c1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: - from vllm._aiter_ops import rocm_aiter_ops - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None )