diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index f25148abb619c..7fecda2166ef0 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -69,30 +69,67 @@ def cutlass_scaled_mm(
 
 
 def rocm_aiter_gemm_w8a8_blockscale_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
+    def is_aiter_triton_kernel_tuned(n, k):
+        return (n, k) in [
+            (1024, 8192),
+            (2112, 7168),
+            (3072, 1536),
+            (32768, 8192),
+            (4096, 7168),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2048),
+            (7168, 256),
+            (8192, 1024),
+            (8192, 32768),
+        ]
 
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    n, k = weight.shape
+    if input_scale is not None:
+        q_input = input_2d
+    elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+        # MI350 case uses triton kernel
+        q_input, input_scale = per_token_group_quant_fp8(
+            input_2d,
+            group_size,
+            column_major_scales=False,
+            use_ue8m0=False,
+        )
+    else:
+        # MI300 uses tuned AITER ASM/C++ kernel
+        import aiter as rocm_aiter
+        from aiter import gemm_a8w8_blockscale, get_hip_quant
+
+        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+        q_input, input_scale = aiter_per1x128_quant(
+            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
+        )
+
+    return gemm_a8w8_blockscale(
+        q_input, weight, input_scale, weight_scale, dtype=output_dtype
+    )
 
 
 def rocm_aiter_gemm_w8a8_blockscale_fake(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    m = A.shape[0]
-    n = B.shape[0]
-    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
-    return Y
+    m = input_2d.shape[0]
+    n = weight.shape[0]
+    return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
 
 
 if current_platform.is_rocm():
@@ -101,15 +138,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
     )
-    if (
-        envs.VLLM_ROCM_USE_AITER
-        and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
-    ):
-        import aiter as rocm_aiter
-        from aiter import get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
 
 
 # TODO we should be able to change the type of block_size to GroupShape
@@ -293,7 +321,9 @@ class W8A8BlockFp8LinearOp:
         ):
             output = self._run_deepgemm(input_2d, weight, weight_scale)
         else:
-            output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
+            output = self.w8a8_blockscale_op(
+                input_2d, weight, weight_scale, input_scale
+            )
 
         if bias is not None:
             output = output + bias
@@ -322,7 +352,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         if self.is_hopper:
@@ -350,17 +382,15 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
-        )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
-            q_input,
+            input_2d,
             weight,
             input_scale,
             weight_scale,
-            list(self.weight_group_shape),
+            self.act_quant_group_shape.col,
            input_2d.dtype,
         )
 
@@ -369,7 +399,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
@@ -391,6 +423,7 @@ class W8A8BlockFp8LinearOp:
                 torch.Tensor,
                 torch.Tensor,
                 torch.Tensor,
+                torch.Tensor | None,
             ],
             torch.Tensor,
         ],
@@ -939,13 +972,11 @@ def requant_weight_ue8m0_inplace(
 
 
 def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm and only for FP8_FNUZ
-    and at the moment are MI300 series"""
+    """AITER is only supported on ROCm for MI3XX"""
     return (
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
     )
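
Reviewer note (not part of the patch): the sketch below illustrates in plain Python how the rewritten rocm_aiter_gemm_w8a8_blockscale_impl above chooses its quantization + GEMM path. The helper names (choose_blockscale_path, TUNED_NK) are hypothetical and the hardware/library calls (AITER HIP quant, Triton kernel, per_token_group_quant_fp8) are stubbed out; only the branch structure mirrors the diff.

# Illustrative dispatch sketch -- mirrors the branch order in the diff above;
# names here are hypothetical and no ROCm/AITER dependency is needed to run it.
TUNED_NK = {
    (1024, 8192), (2112, 7168), (3072, 1536), (32768, 8192),
    (4096, 7168), (4608, 7168), (512, 7168), (7168, 2048),
    (7168, 256), (8192, 1024), (8192, 32768),
}

def choose_blockscale_path(has_input_scale: bool, is_fp8_fnuz: bool, n: int, k: int) -> str:
    if has_input_scale:
        # Activations arrive already quantized; the op skips quantization.
        return "pre-quantized"
    if not is_fp8_fnuz and (n, k) in TUNED_NK:
        # OCP FP8 platform (MI350-style) with a tuned (n, k): Triton kernel
        # fed by per-token-group FP8 quantization.
        return "triton"
    # FNUZ FP8 platform (MI300-style) or untuned shape: AITER per-1x128
    # HIP quant followed by the ASM/C++ kernel.
    return "aiter-asm"

assert choose_blockscale_path(False, False, 512, 7168) == "triton"
assert choose_blockscale_path(False, True, 512, 7168) == "aiter-asm"
assert choose_blockscale_path(True, False, 640, 640) == "pre-quantized"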