mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 16:27:01 +08:00
[ROCm] Serving Fails on Radeon Due to AITER Dtype Import (#30952)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
f90d3636e2
commit
96bf50a2c0
@ -24,14 +24,13 @@ def is_aiter_found() -> bool:
|
|||||||
# we keep this global outside to not cause torch compile breaks.
|
# we keep this global outside to not cause torch compile breaks.
|
||||||
IS_AITER_FOUND = is_aiter_found()
|
IS_AITER_FOUND = is_aiter_found()
|
||||||
|
|
||||||
# Can't use dtypes.fp8 directly inside an op
|
|
||||||
# because it returns the wrong result on gfx942.
|
|
||||||
# This is a workaround to get the correct FP8 dtype.
|
|
||||||
# This might be because get_gfx() is wrapped as a custom op.
|
|
||||||
if IS_AITER_FOUND:
|
|
||||||
from aiter import dtypes
|
|
||||||
|
|
||||||
AITER_FP8_DTYPE = dtypes.fp8
|
def is_aiter_found_and_supported() -> bool:
|
||||||
|
if current_platform.is_rocm() and IS_AITER_FOUND:
|
||||||
|
from vllm.platforms.rocm import on_gfx9
|
||||||
|
|
||||||
|
return on_gfx9()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def if_aiter_supported(func: Callable) -> Callable:
|
def if_aiter_supported(func: Callable) -> Callable:
|
||||||
@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable:
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
# checks the platform, device arch and aiter library existence.
|
# checks the platform, device arch and aiter library existence.
|
||||||
|
|
||||||
if current_platform.is_rocm() and IS_AITER_FOUND:
|
if is_aiter_found_and_supported():
|
||||||
from vllm.platforms.rocm import on_gfx9
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
if on_gfx9():
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# Can't use dtypes.fp8 directly inside an op
|
||||||
|
# because it returns the wrong result on gfx942.
|
||||||
|
# This is a workaround to get the correct FP8 dtype.
|
||||||
|
# This might be because get_gfx() is wrapped as a custom op.
|
||||||
|
if is_aiter_found_and_supported():
|
||||||
|
from aiter import dtypes
|
||||||
|
|
||||||
|
AITER_FP8_DTYPE = dtypes.fp8
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_fused_moe_impl(
|
def _rocm_aiter_fused_moe_impl(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user