From a425dc256e4c2f76f98be136cd898b43f02e6a32 Mon Sep 17 00:00:00 2001
From: TJian
Date: Fri, 14 Nov 2025 10:30:50 -0800
Subject: [PATCH] [Bugfix] [ROCm] [AITER]: Fix aiter block quant not compatible
 with torch compile dynamo (#28716)

Signed-off-by: tjtanaa
---
 tests/rocm/aiter/test_grouped_quant.py          | 137 ++++++++++++++++++
 vllm/_aiter_ops.py                              |  48 +++++-
 .../layers/quantization/utils/fp8_utils.py      |   2 +-
 3 files changed, 180 insertions(+), 7 deletions(-)
 create mode 100644 tests/rocm/aiter/test_grouped_quant.py

diff --git a/tests/rocm/aiter/test_grouped_quant.py b/tests/rocm/aiter/test_grouped_quant.py
new file mode 100644
index 0000000000000..c7f0f1eda3558
--- /dev/null
+++ b/tests/rocm/aiter/test_grouped_quant.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# This is a test for the AITER group_fp8_quant op.
+# It tests that the AITER op:
+# 1. has a correctly defined relationship between
+#    its implementation and its fake function
+# 2. can be used with torch.compile
+# 3. can be used with CUDA graphs
+# This file will be skipped if AITER is not installed
+# or the platform is not ROCm.
+
+import importlib.util
+
+import pytest
+import torch
+
+# this import statement is needed to ensure the ops are registered
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.platforms import current_platform
+
+# Check if aiter package is installed
+aiter_available = importlib.util.find_spec("aiter") is not None
+
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_rocm() and aiter_available),
+    reason="AITER ops are only available on ROCm with aiter package installed",
+)
+
+
+def test_rocm_aiter_group_fp8_quant_fake_implementation():
+    """Test that the fake implementation is correctly
+    defined for torch.ops.vllm.rocm_aiter_group_fp8_quant."""
+    # Create test tensors
+    M = 128
+    N = 4096
+    group_size = 128
+
+    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
+
+    # Verify the op's fake implementation using torch.library.opcheck
+    # This checks that the fake function returns tensors
+    # with correct shapes and dtypes
+    torch.library.opcheck(
+        torch.ops.vllm.rocm_aiter_group_fp8_quant,
+        (input_tensor, group_size),
+        test_utils=("test_faketensor",),
+    )
+
+
+def test_rocm_aiter_group_fp8_quant_torch_compile_with_cudagraph():
+    """Test that rocm_aiter_ops.group_fp8_quant
+    with group size 128 can be used with
+    torch.compile in cudagraph mode."""
+    # Create test tensors
+    M = 128
+    N = 4096
+    group_size = 128
+
+    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
+
+    # Define a function that uses the op
+    def group_fp8_quant_fn(x):
+        return rocm_aiter_ops.group_fp8_quant(x, group_size)
+
+    # Compile with cudagraph mode
+    compiled_fn = torch.compile(
+        group_fp8_quant_fn,
+        fullgraph=True,
+        backend="inductor",
+        mode="reduce-overhead",
+        dynamic=False,
+    )
+
+    # Run eager mode
+    x_fp8_eager, scales_eager = group_fp8_quant_fn(input_tensor)
+
+    # Run compiled version (first run will trigger compilation)
+    x_fp8_compiled, scales_compiled = compiled_fn(input_tensor)
+
+    # Verify shapes match
+    assert x_fp8_compiled.shape == x_fp8_eager.shape
+    assert scales_compiled.shape == scales_eager.shape
+
+    # Verify expected shapes
+    assert x_fp8_compiled.shape == (M, N)
+    expected_scale_cols = (N + group_size - 1) // group_size
+    assert scales_compiled.shape == (M, expected_scale_cols)
+
+    # Verify results match
+    assert torch.allclose(
+        x_fp8_compiled.to(torch.float32),
+        x_fp8_eager.to(torch.float32),
+        rtol=1e-2,
+        atol=1e-2,
+    )
+    assert torch.allclose(scales_compiled, scales_eager, rtol=1e-3, atol=1e-3)
+
+    # Test with different input (reusing compiled graph)
+    input_tensor_2 = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
+    x_fp8_eager_2, scales_eager_2 = group_fp8_quant_fn(input_tensor_2)
+    x_fp8_compiled_2, scales_compiled_2 = compiled_fn(input_tensor_2)
+
+    # Verify second run also produces correct results
+    assert torch.allclose(
+        x_fp8_compiled_2.to(torch.float32),
+        x_fp8_eager_2.to(torch.float32),
+        rtol=1e-2,
+        atol=1e-2,
+    )
+    assert torch.allclose(scales_compiled_2, scales_eager_2, rtol=1e-3, atol=1e-3)
+
+
+def test_rocm_aiter_group_fp8_quant_different_shapes():
+    """Test rocm_aiter_ops.group_fp8_quant with different input shapes."""
+    group_size = 128
+
+    test_shapes = [
+        (64, 2048),
+        (256, 8192),
+        (32, 1024),
+        (512, 4096),
+    ]
+
+    for M, N in test_shapes:
+        input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
+
+        x_fp8, scales = rocm_aiter_ops.group_fp8_quant(input_tensor, group_size)
+
+        # Verify shapes
+        assert x_fp8.shape == (M, N)
+        expected_scale_cols = (N + group_size - 1) // group_size
+        assert scales.shape == (M, expected_scale_cols)
+
+        # Verify dtypes
+        from aiter import dtypes
+
+        assert x_fp8.dtype == dtypes.fp8
+        assert scales.dtype == torch.float32
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 7c35bf1857bae..e53e4ae6e5296 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -43,6 +43,36 @@ def if_aiter_supported(func: Callable) -> Callable:
     return wrapper
 
 
+def _rocm_aiter_group_fp8_quant_impl(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
+    from aiter import QuantType, dtypes, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
+    return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
+
+
+def _rocm_aiter_group_fp8_quant_fake(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter import dtypes
+
+    M, N = x.shape
+    x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
+    out_bs = torch.empty(
+        (
+            M,
+            (N + group_size - 1) // group_size,
+        ),
+        dtype=torch.float32,
+        device=x.device,
+    )
+    return x_fp8, out_bs
+
+
 def _rocm_aiter_fused_moe_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -512,6 +542,14 @@ class rocm_aiter_ops:
         )
 
         # register all the custom ops here
+        direct_register_custom_op(
+            op_name="rocm_aiter_group_fp8_quant",
+            op_func=_rocm_aiter_group_fp8_quant_impl,
+            mutates_args=[],
+            fake_impl=_rocm_aiter_group_fp8_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         direct_register_custom_op(
             op_name="rocm_aiter_asm_moe_tkw1",
             op_func=_rocm_aiter_asm_moe_tkw1_impl,
@@ -887,14 +925,12 @@ class rocm_aiter_ops:
         return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
     @staticmethod
-    def per_1x128_fp8_quant(
+    def group_fp8_quant(
         input_2d: torch.Tensor,
+        group_size: int = 128,
     ) -> tuple[torch.Tensor, ...]:
-        """Only applies quantization method for fp8 data type only."""
-        from aiter import QuantType, dtypes, get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
-        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
+        assert group_size == 128, "Group size must be 128"
+        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
 
     @staticmethod
     def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 541c6c631053d..ae63b4a767268 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -342,7 +342,7 @@ class W8A8BlockFp8LinearOp:
             )
         # MI300 uses tuned AITER ASM/C++ kernel
         else:
-            q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
             return gemm_a8w8_blockscale_op(
                 q_input,
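
For reference, below is a minimal, self-contained sketch of the pattern this patch relies on: a custom op is registered together with a fake (meta) implementation, so torch.compile/Dynamo can trace it from output shapes and dtypes alone instead of failing on the opaque HIP kernel. Everything in the sketch is a hypothetical illustration: the "toy::group_quant" namespace, the abs-max stand-in quantizer, and the helper names are not vLLM or AITER code. vLLM wraps the same steps inside direct_register_custom_op, as the diff above shows.

import torch

# Declare the op schema, then attach a real implementation and a fake
# (shape/dtype-only) implementation.
torch.library.define(
    "toy::group_quant", "(Tensor x, int group_size) -> (Tensor, Tensor)"
)


@torch.library.impl("toy::group_quant", "cpu")
def _group_quant_impl(x: torch.Tensor, group_size: int):
    # Stand-in for the quant kernel: scale each 1 x group_size block of x
    # by its absolute maximum.
    M, N = x.shape
    groups = x.float().view(M, N // group_size, group_size)
    scales = groups.abs().amax(dim=-1).clamp(min=1e-8)  # (M, N // group_size)
    q = (groups / scales.unsqueeze(-1)).view(M, N)
    return q, scales


@torch.library.register_fake("toy::group_quant")
def _group_quant_fake(x: torch.Tensor, group_size: int):
    # Metadata only: Dynamo/FakeTensor tracing calls this, so no device
    # kernel runs at compile time. Shapes and dtypes must match the real
    # implementation, which is what torch.library.opcheck verifies.
    M, N = x.shape
    q = torch.empty((M, N), dtype=torch.float32, device=x.device)
    scales = torch.empty((M, N // group_size), dtype=torch.float32, device=x.device)
    return q, scales


def fn(x: torch.Tensor):
    return torch.ops.toy.group_quant(x, 128)


x = torch.randn(4, 256)
compiled = torch.compile(fn, fullgraph=True)
q_ref, s_ref = fn(x)
q_cmp, s_cmp = compiled(x)
torch.testing.assert_close(q_cmp, q_ref)
torch.testing.assert_close(s_cmp, s_ref)

Without the register_fake entry, the torch.compile call above typically fails (or graph-breaks) during tracing, which is the incompatibility this patch addresses for the AITER block-quant op; the new test file exercises the same behavior through opcheck, torch.compile, and CUDA graph capture.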