From e626d286f5ac997bca5bd635c3db9e457fe91df9 Mon Sep 17 00:00:00 2001 From: TJian Date: Sun, 27 Jul 2025 22:07:06 -0700 Subject: [PATCH] [FEAT] [ROCm] [AITER]: Add AITER HIP block quant kernel (#21242) --- .../layers/quantization/utils/fp8_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 8a7e809d082b1..2aece9a1dee06 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -82,6 +82,13 @@ if current_platform.is_rocm(): fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, dispatch_key=current_platform.dispatch_key, ) + if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()): + + import aiter as rocm_aiter + from aiter import get_hip_quant + + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) def dispatch_w8a8_blockscale_func( @@ -178,8 +185,12 @@ def apply_w8a8_block_fp8_linear( block_size, input.dtype) else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + if use_aiter_and_is_supported: + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype)