[ROCm] [AITER] [DOC] Add usage description about check functions in _aiter_ops (#30586)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent 6f15ac5de7
commit d0fb572929
@@ -642,48 +642,130 @@ _OPS_REGISTERED = False
class rocm_aiter_ops:
    """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.

    This class centralizes the import and registration of AITER ops and
    provides a unified interface for checking whether AITER is enabled.
    Operations are only available on supported gfx9 architectures when
    aiter is installed.

    The class uses environment variables to control which features are
    enabled, allowing fine-grained control over which AITER optimizations
    are used.

    Environment Variables:
        VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
        VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
        VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
        VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
        VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
        VLLM_ROCM_USE_AITER_PAGED_ATTN: Controls AITER paged attention ops.
        VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
        VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
        VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
        VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
        VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
        VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
        VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.

    Note:
        The environment variables are read once, when this module is imported,
        so changing them afterwards has no effect. This is done for
        performance: accessing environment variables is expensive, as
        described in https://github.com/vllm-project/vllm/issues/17067,
        so we avoid doing it repeatedly, especially in the hot path (the
        forward pass). Call refresh_env_variables() to reload the env
        variables, e.g. after monkey patching them in a unit test.

    Check Functions:
        All check functions (is_*_enabled) are decorated with
        @if_aiter_supported, which verifies that (1) the platform is ROCm,
        (2) the device arch is gfx9, and (3) the aiter library is installed.
        The check function then also verifies that the corresponding
        environment variable is enabled, i.e.

            is_enabled() == current_platform.is_rocm() and        ___
                            current_platform.is_on_gfx9() and       | checked by
                            IS_AITER_FOUND and  ____________________| @if_aiter_supported
                            cls._AITER_ENABLED  # checked by the logic in `is_enabled()`

    Example:
        from vllm._aiter_ops import rocm_aiter_ops

        # Check if aiter is enabled before using operations
        if rocm_aiter_ops.is_enabled():
            result = rocm_aiter_ops.rms_norm(x, weight, epsilon)

    Operations:
        - RMS normalization: rms_norm, rms_norm2d_with_add
        - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
        - Fused MoE: fused_moe, asm_moe_tkw1
        - Routing: topk_softmax, biased_grouped_topk, grouped_topk
        - MLA decode: mla_decode_fwd
        - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
        - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
    """

    # Env variable values, read once when the module is imported.
    _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
    _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
    _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
    _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
    _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
    _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
    _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    # TODO: Consolidate under _LINEAR_ENABLED
    _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
    # TODO: Consolidate under _LINEAR_ENABLED
    _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
    # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
    _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
    _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
    # TODO: Consolidate under _LINEAR_ENABLED
    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

    @classmethod
    def refresh_env_variables(cls):
        """Reload the cached env variable values from the environment.

        The environment variables are assigned when the module is imported,
        so this helper re-reads them on demand. For example, after monkey
        patching the env variables in a unit test, call this function to
        reload them.
        """
        cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
        cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
        cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
        cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
        cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
        cls._PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
        cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
        cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
        cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
        cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
        cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
        cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
        cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        """Verifies device specs and that the main AITER env variable is set."""
        return cls._AITER_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls.is_linear_enabled()

    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._FMOE_ENABLED

    @classmethod
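
The refresh flow described in the docstring is typically exercised in a unit
test. A minimal sketch, assuming pytest's monkeypatch fixture and that
vllm.envs re-reads os.environ on each attribute access (the test name is
hypothetical):

    import pytest

    from vllm._aiter_ops import rocm_aiter_ops


    def test_rmsnorm_env_toggle(monkeypatch: pytest.MonkeyPatch):
        # The class snapshotted the env variables at import time,
        # so flip them first...
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "0")
        # ...then re-read them into the cached class attributes.
        rocm_aiter_ops.refresh_env_variables()
        assert not rocm_aiter_ops.is_rmsnorm_enabled()
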
@@ -694,25 +776,16 @@ class rocm_aiter_ops:
    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._MLA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._MHA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_pa_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        """Verifies device specs and availability of the env variable."""
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

    @classmethod
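
All of the is_*_enabled checks above share one guard. A minimal sketch of the
pattern such a decorator implements (guarded_check and platform_ok are
hypothetical names for illustration, not the actual if_aiter_supported
implementation):

    import functools
    from collections.abc import Callable


    def guarded_check(platform_ok: Callable[[], bool]):
        """Short-circuit a boolean check to False when the platform
        requirements (ROCm + gfx9 + aiter installed) are not met."""

        def decorator(check_fn: Callable[..., bool]) -> Callable[..., bool]:
            @functools.wraps(check_fn)
            def wrapper(*args, **kwargs) -> bool:
                if not platform_ok():
                    # Never consult the env flags on unsupported devices.
                    return False
                return check_fn(*args, **kwargs)

            return wrapper

        return decorator

This keeps every per-feature check a one-line conjunction of the cached env
flags, while the device gate lives in a single place.
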
@@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention(
    alibi_slopes: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> bool:
    from vllm._aiter_ops import rocm_aiter_ops

    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
        and (gqa_ratio >= 1 and gqa_ratio <= 16)
        and max_seq_len <= 128 * 1024
        and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
        and not (rocm_aiter_ops.is_pa_attn_enabled())
        and sinks is None
    )
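
To see which branch this gate takes on a given machine, the arch probe can be
run standalone. A minimal sketch, assuming a ROCm build of PyTorch with a
visible GPU (gcnArchName is ROCm-specific and absent on CUDA builds):

    import torch

    if torch.cuda.is_available():
        # gcnArchName reports strings like "gfx942:sramecc+:xnack-".
        gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
        on_gfx9 = any(arch in gpu_arch for arch in ["gfx90a", "gfx942", "gfx950"])
        print(f"{gpu_arch=} {on_gfx9=}")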