From 96bf50a2c0142597e83de39503ccb7cfc7732d95 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 18 Dec 2025 19:47:46 +0800 Subject: [PATCH] [ROCm] Serving Fails on Radeon Due to AITER Dtype Import (#30952) Signed-off-by: vllmellm --- vllm/_aiter_ops.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c32bf04c71c1f..0eae279acf5be 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -24,14 +24,13 @@ def is_aiter_found() -> bool: # we keep this global outside to not cause torch compile breaks. IS_AITER_FOUND = is_aiter_found() -# Can't use dtypes.fp8 directly inside an op -# because it returns wrong result on gfx942. -# This is a workaround to get the correct FP8 dtype. -# This might because that the get_gfx() is wrapped as a custom op. -if IS_AITER_FOUND: - from aiter import dtypes - AITER_FP8_DTYPE = dtypes.fp8 +def is_aiter_found_and_supported() -> bool: + if current_platform.is_rocm() and IS_AITER_FOUND: + from vllm.platforms.rocm import on_gfx9 + + return on_gfx9() + return False def if_aiter_supported(func: Callable) -> Callable: @@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable: def wrapper(*args, **kwargs): # checks the platform, device arch and aiter library existence. - if current_platform.is_rocm() and IS_AITER_FOUND: - from vllm.platforms.rocm import on_gfx9 - - if on_gfx9(): - return func(*args, **kwargs) + if is_aiter_found_and_supported(): + return func(*args, **kwargs) return None return wrapper +# Can't use dtypes.fp8 directly inside an op +# because it returns wrong result on gfx942. +# This is a workaround to get the correct FP8 dtype. +# This might because that the get_gfx() is wrapped as a custom op. +if is_aiter_found_and_supported(): + from aiter import dtypes + + AITER_FP8_DTYPE = dtypes.fp8 + + def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor,