diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 63726c07b7d18..c63196b893574 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -316,38 +316,39 @@ class W8A8BlockFp8LinearOp:
         assert self.act_quant_group_shape == GroupShape(1, 128)
 
         n, k = weight.shape
 
-        if input_scale is not None:
-            q_input = input_2d
-        # MI350 case uses triton kernel
-        if (
+        use_triton = (
             not current_platform.is_fp8_fnuz()
             and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
-        ):
+        )
+
+        if use_triton:
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
+        else:
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+
+        if input_scale is not None:
+            q_input = input_2d
+        # MI350 case uses triton kernel
+        elif use_triton:
             q_input, input_scale = per_token_group_quant_fp8(
                 input_2d,
                 self.act_quant_group_shape.col,
                 column_major_scales=False,
                 use_ue8m0=False,
             )
-            return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
-                q_input,
-                weight,
-                input_scale,
-                weight_scale,
-                input_2d.dtype,
-            )
-        # MI300 uses tuned AITER ASM/C++ kernel
         else:
             q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
-            return rocm_aiter_ops.gemm_w8a8_blockscale(
-                q_input,
-                weight,
-                input_scale,
-                weight_scale,
-                input_2d.dtype,
-            )
+
+        return gemm_a8w8_blockscale_op(
+            q_input,
+            weight,
+            input_scale,
+            weight_scale,
+            list(self.weight_group_shape),
+            output_dtype=input_2d.dtype,
+        )
 
     def _run_triton(
         self,
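
For context, the hunk above collapses two duplicated return paths into a single call site: the blockscale GEMM implementation is selected once up front (the Triton kernel on non-fnuz parts where is_triton_gemm_w8a8_tuned(n, k) holds, the AITER ASM/C++ kernel otherwise), and only the activation quantization step still differs per branch. The standalone sketch below is a minimal illustration of that dispatch pattern under simplified assumptions; the kernel and quantization callables are hypothetical stand-ins, not the rocm_aiter_ops API.

# Illustrative sketch of the dispatch pattern used in the hunk above.
# The *_kernel / *_quant callables are hypothetical stand-ins, not vLLM APIs.
from typing import Callable, Optional, Tuple

import torch


def blockscale_linear(
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    use_triton: bool,
    triton_kernel: Callable[..., torch.Tensor],
    asm_kernel: Callable[..., torch.Tensor],
    triton_quant: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    asm_quant: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    # Pick the GEMM implementation once, instead of returning from each branch.
    gemm_op = triton_kernel if use_triton else asm_kernel

    # Only the activation quantization differs between the branches.
    if input_scale is not None:
        q_input = input_2d  # activations were already quantized upstream
    elif use_triton:
        q_input, input_scale = triton_quant(input_2d)
    else:
        q_input, input_scale = asm_quant(input_2d)

    # Single call site: both kernels receive the same argument list.
    return gemm_op(q_input, weight, input_scale, weight_scale,
                   output_dtype=input_2d.dtype)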