diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 299c8219120ae..5820832ed4860 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake( return Y +def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_a8w8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + def _rocm_aiter_gemm_a8w8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, @@ -964,6 +989,12 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_a8w8_blockscale", + op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl, + fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8_blockscale", op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, @@ -1102,6 +1133,19 @@ class rocm_aiter_ops: ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype) + @staticmethod + def triton_gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale( + A, B, As, Bs, output_dtype + ) + @staticmethod def gemm_a8w8_blockscale( A: torch.Tensor, @@ -1373,19 +1417,6 @@ class rocm_aiter_ops: config=config, ) - @staticmethod - def triton_gemm_a8w8_blockscale( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float16, - ) -> torch.Tensor: - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) - @staticmethod def group_fp8_quant( input_2d: torch.Tensor,