diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6b72c595cd779..7755e9f9b7380 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools import pytest import torch @@ -53,37 +52,61 @@ class TestModel(torch.nn.Module): hidden_size: int, eps: float, group_shape: GroupShape, - cuda_force_torch: bool, + use_aiter: bool = False, + cuda_force_torch: bool = False, + use_aiter_quant_op: bool = True, *args, **kwargs, ): super().__init__(*args, **kwargs) + self.use_aiter = use_aiter + self.use_aiter_quant_op = use_aiter_quant_op self.cuda_force_torch = cuda_force_torch + self.group_shape = group_shape + self.enable_quant_fp8_custom_op = None # Will be set later if applicable + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - if group_shape.is_per_group(): - self.wscale = [ - torch.rand( - (hidden_size // group_shape[1], hidden_size // group_shape[1]), - dtype=torch.float32, - ) - for _ in range(3) - ] - else: - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - static = group_shape == GroupShape.PER_TENSOR + + # Setup quantization scale descriptor + static = group_shape == GroupShape.PER_TENSOR and not use_aiter quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + + # Setup scales if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: self.scale = [None for _ in range(3)] + + # Setup weights self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] - if not group_shape.is_per_group(): + if not group_shape.is_per_group() or use_aiter: self.w = [self.w[0].t() for _ in range(3)] + # Setup weight scales if group_shape.is_per_group(): + scale_size = ( + (hidden_size + 128 - 1) // 128 + if use_aiter + else hidden_size // group_shape[1] + ) + wscale_shape: tuple[int, ...] 
= (scale_size, scale_size) + else: + wscale_shape = (1,) + self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] + + # Setup FP8 linear operation + is_per_group = group_shape.is_per_group() + if is_per_group and use_aiter: + self.fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=group_shape, + use_aiter_and_is_supported=use_aiter_quant_op, + ) + # AITER blockwise doesn't use enable_quant_fp8_custom_op + elif is_per_group: self.fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(group_shape[1], group_shape[1]), act_quant_group_shape=group_shape, @@ -91,6 +114,13 @@ class TestModel(torch.nn.Module): use_aiter_and_is_supported=False, ) self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() + elif use_aiter: + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + act_quant_group_shape=group_shape, + ) + self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() else: with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = Fp8LinearOp( @@ -100,7 +130,6 @@ class TestModel(torch.nn.Module): self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.group_shape = group_shape def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -126,19 +155,49 @@ class TestModel(torch.nn.Module): y4, resid = self.norm[3](x4, resid) # use resid here return y4 + def ops_in_model_before(self): + if ( + self.use_aiter + and self.group_shape.is_per_group() + and current_platform.is_fp8_fnuz() + ): + return [rocm_aiter_ops.get_group_quant_op()] + if self.use_aiter and self.group_shape.is_per_group(): + return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] + if self.use_aiter and self.use_aiter_quant_op: + return [rocm_aiter_ops.get_per_token_quant_op()] + if self.use_aiter: + return [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8_custom_op: + return [QUANT_OPS[self.quant_key]] + return [torch.ops.aten.reciprocal] + def ops_in_model_after(self): + if self.use_aiter and self.group_shape.is_per_group(): + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSFp8GroupQuantPattern, + AiterRMSFp8GroupQuantPattern, + ) + + return [ + AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, + AiterRMSFp8GroupQuantPattern.FUSED_OP, + ] + if self.use_aiter: + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSNormDynamicQuantPattern, + AiterRMSNormDynamicQuantPattern, + ) + + return [ + AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, + AiterRMSNormDynamicQuantPattern.FUSED_OP, + ] return [ FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], ] - def ops_in_model_before(self): - return ( - [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8_custom_op - else [torch.ops.aten.reciprocal] - ) - def ops_in_model_before_partial(self): return ( [RMS_OP, RMS_ADD_OP] @@ -155,67 +214,45 @@ GROUP_SHAPES = [ ] -class TestRmsnormGroupFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, **kwargs): - super().__init__() - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(128, 128), - act_quant_group_shape=GroupShape(1, 128), - cutlass_block_fp8_supported=False, - use_aiter_and_is_supported=True, - ) - self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in 
range(3) - ] +def _run_fusion_test( + model, + fusion_pass, + vllm_config, + dtype, + hidden_size, + num_tokens, +): + """Helper function for common fusion test logic. - scale_hidden_size = (hidden_size + 128 - 1) // 128 - self.wscale = [ - torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) - for _ in range(3) - ] + Must be called within vllm_config context. + """ + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] - self.eps = eps + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) - def forward(self, x): - # avoid having graph input be an arg to a pattern directly - x = resid = torch.relu(x) - y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps) + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) - x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) - # make sure resid is used for replacement to work - y2, resid = rocm_aiter_ops.rms_norm2d_with_add( - x2, resid, self.norm_weight[1], self.eps - ) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - y3, resid = rocm_aiter_ops.rms_norm2d_with_add( - x3, resid, self.norm_weight[2], self.eps - ) + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) - x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - y4, resid = rocm_aiter_ops.rms_norm2d_with_add( - x4, resid, self.norm_weight[3], self.eps - ) - return y4 + assert fusion_pass.matched_count == 3 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) - def ops_in_model_before(self): - return [ - torch.ops.vllm.rocm_aiter_rms_norm, - torch.ops.vllm.rocm_aiter_group_fp8_quant, - ] - - def ops_in_model_before_partial(self): - return [] - - def ops_in_model_after(self): - return [ - torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant, - torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant, - ] + return backend, backend2 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("group_shape", GROUP_SHAPES) -@pytest.mark.parametrize( - "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", - list(itertools.product([TestModel], [True, False], [True, False])) - + [(TestRmsnormGroupFp8QuantModel, False, False)], -) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
@pytest.mark.parametrize( @@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, group_shape, - model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, ): - if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND: - pytest.skip("AITER is not supported on this GPU.") - - torch.set_default_device("cuda") - torch.set_default_dtype(dtype) - torch.manual_seed(1) - maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - if not enable_quant_fp8_custom_op and group_shape.is_per_group(): pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") - # Skip test for 64-bit group shape when running with cutlass or deepgemm if group_shape == GroupShape(1, 64) and ( cutlass_block_fp8_supported() or is_deep_gemm_supported() ): @@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant( custom_ops.append("+rms_norm") if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( @@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant( ), ), ) + with vllm.config.set_current_vllm_config(vllm_config): - # Reshape pass is needed for the fusion pass to work - noop_pass = NoOpEliminationPass(vllm_config) - if model_class is TestRmsnormGroupFp8QuantModel: - from vllm.compilation.rocm_aiter_fusion import ( - RocmAiterRMSNormFp8GroupQuantFusionPass, - ) + # Setup device before model creation + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + maybe_create_device_identity() - fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) - else: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) - backend2 = TestBackend(noop_pass, cleanup_pass) - model = model_class( + fusion_pass = RMSNormQuantFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, group_shape=group_shape, + use_aiter=False, cuda_force_torch=cuda_force_torch, ) - # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) - torch._dynamo.mark_dynamic(x, 0) - model_fused = torch.compile(model, backend=backend) - result_fused = model_fused(x) - - model_unfused = torch.compile(model, backend=backend2) - result_unfused = model_unfused(x) - - if dtype == torch.float16: - ATOL, RTOL = (2e-3, 2e-3) - else: - ATOL, RTOL = (1e-2, 1e-2) - - torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - - assert fusion_pass.matched_count == 3 - backend.check_before_ops(model.ops_in_model_before()) + backend, _ = _run_fusion_test( + model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens + ) backend.check_before_ops( model.ops_in_model_before_partial(), fully_replaced=False ) - backend.check_after_ops(model.ops_in_model_after()) # If RMSNorm custom op is disabled (native/torch impl used), # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). 
- if ( - not enable_rms_norm_custom_op - and model_class is not TestRmsnormGroupFp8QuantModel - ): + if not enable_rms_norm_custom_op: n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 assert n_add_nodes(backend.graph_post_pass) == 2 + + +GROUP_SHAPE_QUANT_OPS_MATCHS = [ + (GroupShape.PER_TOKEN, True), + (GroupShape.PER_TOKEN, False), + (GroupShape(1, 128), True), +] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [256]) +@pytest.mark.parametrize("num_tokens", [257]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize( + "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS +) +@pytest.mark.skipif( + (not current_platform.is_rocm() or not IS_AITER_FOUND), + reason="Only test on ROCm with aiter package installed", +) +def test_aiter_fusion_rmsnorm_quant( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + group_shape: GroupShape, + use_aiter_quant_op: bool, + monkeypatch: pytest.MonkeyPatch, +): + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass + + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + maybe_create_device_identity() + + fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) + model = TestModel( + hidden_size=hidden_size, + eps=eps, + group_shape=group_shape, + use_aiter=True, + use_aiter_quant_op=use_aiter_quant_op, + ) + + _run_fusion_test( + model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens + ) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 03e3bb7594910..299c8219120ae 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -4,6 +4,7 @@ import functools from collections.abc import Callable import torch +from torch._ops import OpOverload import vllm.envs as envs from vllm.platforms import current_platform @@ -433,16 +434,16 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( from aiter import rmsnorm2d_fwd_with_add residual_out = torch.empty_like(residual) - output = torch.empty_like(x) + out = torch.empty_like(x) rmsnorm2d_fwd_with_add( - output, # output + out, # output x, # input residual, # residual input residual_out, # residual output weight, variance_epsilon, ) - return output, residual_out + return out, residual_out def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( @@ -451,7 +452,84 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) + residual_out = torch.empty_like(residual) + out = torch.empty_like(x) + return out, residual_out + + +def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, 
device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + residual_out = torch.empty_like(x) + + rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant( + out, + x, + residual, + residual_out, + y_scale, + weight, + epsilon, + use_model_sensitive_rmsnorm=0, + ) + + return out, residual_out, y_scale + + +def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + residual_out = torch.empty_like(x) + + return out, residual_out, y_scale + + +def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + + rocm_aiter.rmsnorm2d_fwd_with_dynamicquant( + out, x, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0 + ) + + return out, y_scale + + +def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + + return out, y_scale def _rocm_aiter_per_tensor_quant_impl( @@ -527,7 +605,11 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( dtype_quant=AITER_FP8_DTYPE, res1=residual, ) - return (x_quant, x_quant_scales, res) + return ( + x_quant, + res, + x_quant_scales, + ) def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( @@ -541,8 +623,8 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( scale_shape = (M, (N + group_size - 1) // group_size) return ( torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), - torch.empty(scale_shape, dtype=torch.float32, device=x.device), torch.empty_like(residual, device=residual.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), ) @@ -901,6 +983,20 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_dynamic_quant", + op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant", + op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_rmsnorm_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, @@ -936,13 +1032,54 @@ class rocm_aiter_ops: direct_register_custom_op( op_name="rocm_aiter_per_token_quant", op_func=_rocm_aiter_per_token_quant_impl, - mutates_args=["scale"], fake_impl=_rocm_aiter_per_token_quant_fake, dispatch_key=current_platform.dispatch_key, ) _OPS_REGISTERED = True + @staticmethod + def get_rmsnorm_fused_add_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + + @staticmethod + def get_rmsnorm_op() -> OpOverload: + 
return torch.ops.vllm.rocm_aiter_rms_norm.default + + @staticmethod + def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default + + @staticmethod + def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default + + @staticmethod + def get_rmsnorm_group_fused_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default + + @staticmethod + def get_rmsnorm_group_add_fused_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default + + @staticmethod + def get_per_token_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_per_token_quant.default + + @staticmethod + def get_group_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_group_fp8_quant.default + + @staticmethod + def get_act_mul_fused_fp8_group_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + @staticmethod def rms_norm2d_with_add( x: torch.Tensor, @@ -954,12 +1091,6 @@ class rocm_aiter_ops: x, residual, weight, variance_epsilon ) - @staticmethod - def rms_norm( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float - ) -> torch.Tensor: - return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) - @staticmethod def gemm_a8w8( A: torch.Tensor, diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index ec9ed34f561b4..7301aa3e5932d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,11 +6,13 @@ import torch from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, _normalize_quant_group_shape, kFp8Dynamic64Sym, @@ -150,26 +152,50 @@ class MatcherRotaryEmbedding(MatcherCustomOp): class MatcherRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: bool | None = None): + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + match_rocm_aiter: bool = False, + ): if enabled is None: enabled = RMSNorm.enabled() super().__init__(enabled) self.epsilon = epsilon + self._rmsnorm_op = RMS_OP + self.match_rocm_aiter = match_rocm_aiter + + if match_rocm_aiter: + self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op() def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) weight = self.empty(16) return [input, weight] + def forward_rocm_aiter( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return self._rmsnorm_op( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + ) + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, ) -> torch.Tensor: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, weight) + result = torch.empty_like(input) _, result = auto_functionalized( - RMS_OP, + self._rmsnorm_op, result=result, input=input, weight=weight, @@ -189,12 
+215,23 @@ class MatcherRMSNorm(MatcherCustomOp): class MatcherFusedAddRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: bool | None = None): + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + match_rocm_aiter: bool = False, + ): if enabled is None: enabled = RMSNorm.enabled() super().__init__(enabled) self.epsilon = epsilon + self.match_rocm_aiter = match_rocm_aiter + + self._rmsnorm_op = RMS_ADD_OP + + if match_rocm_aiter: + self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) @@ -202,14 +239,27 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): residual = self.empty(5, 16) return [input, weight, residual] + def forward_rocm_aiter( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self._rmsnorm_op( + x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon + ) + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, weight, residual) + _, result, residual = auto_functionalized( - RMS_ADD_OP, + self._rmsnorm_op, input=input, residual=residual, weight=weight, @@ -236,22 +286,46 @@ class MatcherQuantFP8(MatcherCustomOp): enabled: bool | None = None, has_col_major_scales: bool = False, is_e8m0: bool = False, + match_rocm_aiter: bool = False, ): if enabled is None: enabled = QuantFP8.enabled() super().__init__(enabled) self.quant_key = quant_key - assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" - self.QUANT_OP = QUANT_OPS[quant_key] - self.has_col_major_scales = has_col_major_scales self.is_e8m0 = is_e8m0 + self.match_rocm_aiter = match_rocm_aiter + + if match_rocm_aiter: + assert not quant_key.scale.group_shape.is_per_tensor(), ( + "ROCm aiter fusion pass does not support per tensor quantization" + ) + if quant_key.scale.group_shape.is_per_token(): + self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op() + else: + assert quant_key.scale.group_shape.col == 128, ( + "ROCm aiter fusion pass currently supports " + "quantization operation with group_size 128" + ) + if current_platform.is_fp8_fnuz(): + self.QUANT_OP = rocm_aiter_ops.get_group_quant_op() + else: + self.QUANT_OP = ( + torch.ops.vllm.triton_per_token_group_quant_fp8.default + ) + + else: + assert quant_key in QUANT_OPS, ( + f"unsupported quantization scheme {quant_key}" + ) + self.QUANT_OP = QUANT_OPS[quant_key] + + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) + assert quant_key.scale2 is None - assert quant_key.dtype == current_platform.fp8_dtype(), ( - "Only QuantFP8 supported by" - ) - assert quant_key.scale2 is None self.quant_fp8 = QuantFP8( quant_key.scale.static, quant_key.scale.group_shape, @@ -259,11 +333,29 @@ class MatcherQuantFP8(MatcherCustomOp): use_ue8m0=is_e8m0, ) + def forward_rocm_aiter( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + quant_key_group_shape = self.quant_key.scale.group_shape + if quant_key_group_shape == GroupShape.PER_TOKEN: + return self.QUANT_OP( + x=input, + quant_dtype=self.quant_key.dtype, + scale=scale, + ) + else: + return self.QUANT_OP(input, quant_key_group_shape.col) + def forward_custom( self, input: torch.Tensor, scale: torch.Tensor | None = None, ) -> 
tuple[torch.Tensor, torch.Tensor]: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, scale) + result = torch.empty( input.shape, device=input.device, dtype=self.quant_key.dtype ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 4ebb386f75ed8..4c2dee505a941 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -16,7 +16,7 @@ from .vllm_inductor_pass import VllmInductorPass if rocm_aiter_ops.is_enabled(): from vllm.compilation.rocm_aiter_fusion import ( - RocmAiterRMSNormFp8GroupQuantFusionPass, + RocmAiterRMSNormFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, ) @@ -117,7 +117,9 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): - self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)] + self.passes += [ + RocmAiterRMSNormFusionPass(config), + ] if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 8b5db9de38181..f66bb76b97f05 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -9,60 +9,195 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 +from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.activation_quant_fusion import ActivationQuantPattern from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.platforms import current_platform -from .fusion import empty_bf16 +from .fusion import ( + FusedRMSQuantKey, +) from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherSiluAndMul +from .matcher_utils import ( + MatcherFusedAddRMSNorm, + MatcherQuantFP8, + MatcherRMSNorm, + MatcherSiluAndMul, +) from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() -AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default -AITER_RMS_ADD_GROUP_QUANT_OP = ( - torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default -) -AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default -AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default +class AiterRMSNormQuantPattern: + def __init__( + self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True + ): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype -AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default -TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default - -FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon, match_rocm_aiter=True) + if not key.fused_add + else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True) + ) + self.quant_matcher = MatcherQuantFP8( + key.quant, + match_rocm_aiter=match_aiter_quant, + ) -class AiterRMSFp8GroupQuantPattern: +class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): + """AITER RMSNorm + Dynamic Quantization pattern.""" + + FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op() + + def 
__init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + match_aiter_quant: bool = True, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) + + def register(self, pm_pass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + result = self.FUSED_OP( + x=input, + weight=weight, + epsilon=self.epsilon, + quant_dtype=self.quant_dtype, + ) + + return result[0], result[1] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): + """AITER RMSNorm Fused Add + Dynamic Quantization pattern.""" + + FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + match_aiter_quant: bool = True, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) + + def register(self, pm_pass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ): + result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) + + return result, residual_out, scale + + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): + result = self.FUSED_OP( + x=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + quant_dtype=self.quant_dtype, + ) + + return result[0], result[1], result[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): """ This pattern fuses aiter rms_norm & group fp8 quant custom ops into an aiter rms_norm_group_fp8_quant op. 
""" - def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): - self.epsilon = epsilon - self.quant_dtype = quant_dtype - self.quant_op = quant_op + FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + match_aiter_quant: bool = True, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, weight: torch.Tensor, ): - at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) - - at2 = self.quant_op(at1, 128) - - return at2[0], at2[1] + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale def replacement( input: torch.Tensor, weight: torch.Tensor, ): - at = AITER_RMS_GROUP_QUANT_OP( + at = self.FUSED_OP( x=input, weight=weight, variance_epsilon=self.epsilon, @@ -71,49 +206,52 @@ class AiterRMSFp8GroupQuantPattern: return at[0], at[1] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass + ) -class AiterFusedAddRMSFp8GroupQuantPattern: +class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): """ This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops into a aiter rms_norm_with_add_group_fp8_quant op. """ - def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): - self.epsilon = epsilon - self.quant_dtype = quant_dtype - self.quant_op = quant_op + FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + match_aiter_quant: bool = True, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, ): - at1 = AITER_RMS_ADD_OP( - x=input, - residual=residual, - weight=weight, - variance_epsilon=self.epsilon, - ) + result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - at2 = self.quant_op(at1[0], 128) - - # result, scale, residual - return at2[0], at2[1], at1[1] + return result, residual_out, scale def replacement( input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, ): - at = AITER_RMS_ADD_GROUP_QUANT_OP( + at = self.FUSED_OP( x=input, residual=residual, weight=weight, @@ -124,18 +262,15 @@ class AiterFusedAddRMSFp8GroupQuantPattern: # result, scale, residual return at[0], at[1], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass + ) -class 
RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): +class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): """ - This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + This pass fuses aiter rms_norm & vllm/aiter quant custom ops + into a fused rms_norm_quant op. It also supports fused_add_rms_norm. """ @@ -144,20 +279,33 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass" + pass_name="rocm_aiter_rms_norm_quant_fusion_pass" ) # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + dynamic group fp8 quant - for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: - AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( - self.patterns - ) + # Fuse aiter rms_norm + aiter dynamic group fp8 quant + AiterRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) - AiterFusedAddRMSFp8GroupQuantPattern( - epsilon, FP8_DTYPE, quant_op + # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) + + for match_aiter_quant in [True, False]: + # Fuse aiter rms_norm + (aiter / vllm built-in) + # dynamic per-token fp8 quant + AiterRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant + ).register(self.patterns) + + # Fuse aiter fused_add_rms_norm + (aiter / vllm built-in) + # dynamic per-token fp8 quant + AiterFusedAddRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant ).register(self.patterns) self.dump_patterns(config, self.patterns) @@ -169,6 +317,8 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): def uuid(self) -> Any: fusion_patterns = [ + AiterRMSNormDynamicQuantPattern, + AiterFusedAddRMSNormDynamicQuantPattern, AiterRMSFp8GroupQuantPattern, AiterFusedAddRMSFp8GroupQuantPattern, ] @@ -181,6 +331,8 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): ops into an aiter silu_and_mul_group_fp8_quant op. 
""" + FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() + def __init__(self, quant_op: OpOverload): self.silu_and_mul_matcher = MatcherSiluAndMul() self.quant_op = quant_op @@ -196,7 +348,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): def replacement( input: torch.Tensor, ): - at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) return at[0], at[1] inputs = [ @@ -216,6 +368,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ + AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op() + TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default + + QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -224,7 +381,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" ) - for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + for quant_op in self.QUANT_OPS: AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) self.dump_patterns(config, self.patterns)