# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

import vllm.config
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
    FlashInferScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
    ChannelWiseTorchScaledMMLinearKernel,
    PerTensorTorchScaledMMLinearKernel,
    RowWiseTorchScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)

from ..utils import TestBlockFP8Layer, TestFP8Layer
from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()

RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferScaledMMLinearKernel supports both per-tensor and per-token
    (FlashInferScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (FlashInferScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmScaledMMLinearKernel supports both per-tensor and per-token
    (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (ROCmScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchScaledMMLinearKernel only supports per-token
    (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

KERNEL_GROUPSHAPE_COMBINATIONS = (
    CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
    if current_platform.is_cuda()
    else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)

# For Aiter tests we toggle use_aiter_quant_op
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
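    # Entries are (kernel, group_shape, use_aiter_quant_op)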
    # Per-token with ROCmScaledMMLinearKernel
    (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with RowWiseTorchScaledMMLinearKernel
    (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchScaledMMLinearKernel
    (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise (no kernel abstraction)
    (None, GroupShape(1, 128), True),
]


class TestModel(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: FP8ScaledMMLinearKernel | None,
        group_shape: GroupShape,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            self._init_blockwise(
                hidden_size, group_shape, use_aiter_fusion, use_aiter_quant
            )
        else:
            self._init_nonblockwise(
                hidden_size, group_shape, force_kernel, use_aiter_quant
            )

    def _init_nonblockwise(
        self,
        hidden_size: int,
        group_shape: GroupShape,
        force_kernel: FP8ScaledMMLinearKernel | None,
        use_aiter_quant: bool,
    ):
        """Initialize non-blockwise (per-tensor/per-token) FP8 layers."""
        is_static = group_shape == GroupShape.PER_TENSOR
        act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
        w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
        self.activation_quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
        )
        self.weight_quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
        )

        # Setup weight scales
        wscale_shape = (1,) if group_shape.is_per_tensor() else (hidden_size, 1)
        self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
        self.act_scale = (
            [torch.rand(1, dtype=torch.float32) for _ in range(3)]
            if is_static
            else [None for _ in range(3)]
        )

        # Initialize weights (transposed for non-blockwise)
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(3)
        ]

        # Setup FP8 linear layers with kernel abstraction
        self.fp8_linear_layers = [
            TestFP8Layer(
                self.activation_quant_key,
                self.weight_quant_key,
                self.w[i],
                self.wscale[i],
                input_scale=self.act_scale[i],
                force_kernel=force_kernel,
            )
            for i in range(3)
        ]

        # Enable aiter quantization if requested
        for layer in self.fp8_linear_layers:
            layer.kernel.quant_fp8.use_aiter = use_aiter_quant

        self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
            0
        ].is_quant_fp8_enabled()

    def _init_blockwise(
        self,
        hidden_size: int,
        group_shape: GroupShape,
        use_aiter_fusion: bool,
        use_aiter_quant: bool,
    ):
        """Initialize blockwise FP8 layers."""
        act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
        self.activation_quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
        )

        # Setup weight scales (for blockwise quantization)
        # Use aiter block size if aiter fusion is enabled
        scale_size = (
            (hidden_size + 128 - 1) // 128
            if use_aiter_fusion
            else hidden_size // group_shape[1]
        )
        wscale_shape = (scale_size, scale_size)
        self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
        # Initialize weights (transposed if using aiter, otherwise not)
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
        ]
        if use_aiter_fusion:
            self.w = [w.t() for w in self.w]

        self.fp8_linear_layers = [
            TestBlockFP8Layer(
                group_shape=group_shape,
                weight=self.w[i],
                weight_scale=self.wscale[i],
                input_scale=None,  # Dynamic quantization for blockwise
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=use_aiter_quant,
            )
            for i in range(3)
        ]

        self.enable_quant_fp8_custom_op = (
            False
            if use_aiter_quant
            else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
        )

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)
        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)
        x3 = self.fp8_linear_layers[1](y2)
        y3, resid = self.norm[2](x3, resid)  # use resid here
        x4 = self.fp8_linear_layers[2](y3)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]
        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        if self.use_aiter_fusion:
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]
        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )


def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Helper function for common fusion test logic.

    Must be called within vllm_config context.
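    The model contains three RMSNorm -> FP8 quant chains, so three pattern
    matches are expected; fused and unfused outputs are compared numerically.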
""" noop_pass = NoOpEliminationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) model_fused = torch.compile(model, backend=backend) result_fused = model_fused(x) model_unfused = torch.compile(model, backend=backend2) result_unfused = model_unfused(x) if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) else: ATOL, RTOL = (1e-2, 1e-2) torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) assert fusion_pass.matched_count == 3 backend.check_before_ops(model.ops_in_model_before()) backend.check_after_ops(model.ops_in_model_after()) return backend, backend2 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) def test_fusion_rmsnorm_quant( dtype, hidden_size, num_tokens, eps, kernel_groupshape, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, ): force_kernel, group_shape = kernel_groupshape if not enable_quant_fp8_custom_op and group_shape.is_per_group(): pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") if group_shape == GroupShape(1, 64) and ( cutlass_block_fp8_supported() or is_deep_gemm_supported() ): pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm") custom_ops = [] if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops, pass_config=PassConfig( fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True ), ), ) with vllm.config.set_current_vllm_config(vllm_config): # Setup device before model creation torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) fusion_pass = RMSNormQuantFusionPass(vllm_config) model = TestModel( hidden_size=hidden_size, eps=eps, force_kernel=force_kernel, group_shape=group_shape, use_aiter_fusion=False, use_aiter_quant=False, ) backend, _ = _run_fusion_test( model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens ) backend.check_before_ops( model.ops_in_model_before_partial(), fully_replaced=False ) # If RMSNorm custom op is disabled (native/torch impl used), # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). 
        if not enable_rms_norm_custom_op:
            n_add_nodes = lambda g: sum(
                1 for _ in find_op_nodes(torch.ops.aten.add, g)
            )
            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 7
            assert n_add_nodes(backend.graph_post_pass) == 2


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass

        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
            use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
        )

        _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )