On ROCm, make round_int8 call tl.extra.hip.libdevice.round directly and drop
the local round_f32 extern wrapper around llvm.round (together with its
triton.language.core import), as anticipated by the removed NOTE comment.

diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py
index 925d0a516ce6..32192225f61e 100644
--- a/vllm/model_executor/layers/quantization/utils/int8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py
@@ -83,26 +83,11 @@ def block_dequant(
 
 
 if current_platform.is_rocm():
-    from triton.language import core
-
-    # NOTE: This can be removed when hip.libdevice.round() is available.
-    @core.extern
-    def round_f32(arg0, _builder=None):
-        return core.extern_elementwise(
-            "",
-            "",
-            [arg0],
-            {
-                (core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")),
-                (core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")),
-            },
-            is_pure=True,
-            _builder=_builder,
-        )
 
     @triton.jit
     def round_int8(x):
-        return round_f32(x).to(tl.int8)
+        return tl.extra.hip.libdevice.round(x).to(tl.int8)
+
 else:
 
     @triton.jit