[ROCm] Serving Fails on Radeon Due to AITER Dtype Import (#30952)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-12-18 19:47:46 +08:00 committed by GitHub
parent f90d3636e2
commit 96bf50a2c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,14 +24,13 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def is_aiter_found_and_supported() -> bool:
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
return on_gfx9()
return False
def if_aiter_supported(func: Callable) -> Callable:
@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existence.
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
if on_gfx9():
return func(*args, **kwargs)
if is_aiter_found_and_supported():
return func(*args, **kwargs)
return None
return wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,