[ROCm] Add env to enable/disable aiter triton gemm (#28321)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin 2025-11-08 20:27:00 -10:00 committed by GitHub
parent e5e9067e61
commit de2b78305f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 0 deletions

View File

@ -113,6 +113,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@ -944,6 +945,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
in ("true", "1")
),
# Whether to use aiter triton kernels for gemm ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
@ -1586,6 +1592,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",

View File

@ -106,6 +106,7 @@ def default_unquantized_gemm(
def use_aiter_triton_gemm(n, m, k, dtype):
if (
envs.VLLM_ROCM_USE_AITER == 0
or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
# MI300's - fp8nuz=True
or current_platform.is_fp8_fnuz()
or dtype not in [torch.float16, torch.bfloat16]