diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 2ad34a79859a3..6b72c595cd779 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools + import pytest import torch import vllm.plugins +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.matcher_utils import QUANT_OPS @@ -152,13 +155,79 @@ GROUP_SHAPES = [ ] +class TestRmsnormGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, eps: float, **kwargs): + super().__init__() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(3) + ] + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = [ + torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) + for _ in range(3) + ] + + self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] + self.eps = eps + + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) + y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps) + + x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) + # make sure resid is used for replacement to work + y2, resid = rocm_aiter_ops.rms_norm2d_with_add( + x2, resid, self.norm_weight[1], self.eps + ) + + x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) + + y3, resid = rocm_aiter_ops.rms_norm2d_with_add( + x3, resid, self.norm_weight[2], self.eps + ) + + x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) + + y4, resid = rocm_aiter_ops.rms_norm2d_with_add( + x4, resid, self.norm_weight[3], self.eps + ) + return y4 + + def ops_in_model_before(self): + return [ + torch.ops.vllm.rocm_aiter_rms_norm, + torch.ops.vllm.rocm_aiter_group_fp8_quant, + ] + + def ops_in_model_before_partial(self): + return [] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant, + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant, + ] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("group_shape", GROUP_SHAPES) -@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) -@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) +@pytest.mark.parametrize( + "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", + list(itertools.product([TestModel], [True, False], [True, False])) + + [(TestRmsnormGroupFp8QuantModel, False, False)], +) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
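+# TestRmsnormGroupFp8QuantModel ignores cuda_force_torch and group_shape (they
+# are only accepted via **kwargs) and is parametrized with both custom ops off.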
@pytest.mark.parametrize( @@ -173,10 +242,14 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, group_shape, + model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, ): + if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") + torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -209,12 +282,24 @@ def test_fusion_rmsnorm_quant( with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) + if model_class is TestRmsnormGroupFp8QuantModel: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + ) + + fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) + else: + fusion_pass = RMSNormQuantFusionPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, group_shape, cuda_force_torch) + model = model_class( + hidden_size=hidden_size, + eps=eps, + group_shape=group_shape, + cuda_force_torch=cuda_force_torch, + ) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) @@ -243,7 +328,10 @@ def test_fusion_rmsnorm_quant( # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). - if not enable_rms_norm_custom_op: + if ( + not enable_rms_norm_custom_op + and model_class is not TestRmsnormGroupFp8QuantModel + ): n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index c336a45955cb5..eb0dee8d4e399 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -7,6 +7,7 @@ import torch import vllm.envs as envs from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor +from vllm._aiter_ops import IS_AITER_FOUND from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.activation_quant_fusion import ( FUSED_OPS, @@ -24,6 +25,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, @@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): return [FUSED_OPS[kNvfp4Quant]] +class TestSiluMulGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = torch.rand( + (scale_hidden_size, scale_hidden_size), dtype=torch.float32 + ) + + 
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
+        return x2
+
+    def ops_in_model_before(self):
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+        ]
+
+    def ops_in_model_after(self):
+        return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
+
+
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
 @pytest.mark.parametrize(
     "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
     list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
-    + [(TestSiluMulNvfp4QuantModel, False, False)],
+    + [
+        (TestSiluMulNvfp4QuantModel, False, False),
+        (TestSiluMulGroupFp8QuantModel, False, False),
+    ],
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
@@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
-    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    model_class: type[
+        TestSiluMulFp8QuantModel
+        | TestSiluMulNvfp4QuantModel
+        | TestSiluMulGroupFp8QuantModel
+    ],
     enable_silu_mul_custom_op: bool,
     enable_quant_fp8_custom_op: bool,
     cuda_force_torch: bool,
 ):
     if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
         pytest.skip("NVFP4 is not supported on this GPU.")
+    if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
+        pytest.skip("AITER is not supported on this GPU.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
@@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
     )
 
     with set_current_vllm_config(config):
-        fusion_pass = ActivationQuantFusionPass(config)
+        fusion_passes = [ActivationQuantFusionPass(config)]
+        if IS_AITER_FOUND:
+            from vllm.compilation.rocm_aiter_fusion import (
+                RocmAiterSiluMulFp8GroupQuantFusionPass,
+            )
 
-        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+            fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
+
+        passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
         backend = TestBackend(*passes)
         model = model_class(
             hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
@@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
         atol, rtol = 1e-3, 1e-3
     elif model_class == TestSiluMulNvfp4QuantModel:
         atol, rtol = 1e-1, 1e-1
+    elif model_class == TestSiluMulGroupFp8QuantModel:
+        atol, rtol = 5e-2, 5e-2
 
     torch.testing.assert_close(
         result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
     )
 
-    assert fusion_pass.matched_count == 1
+    assert sum([p.matched_count for p in fusion_passes]) == 1
 
     # In pre-nodes, quant op should be present and fused kernels should not
     backend.check_before_ops(model.ops_in_model_before())
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 94bbc9b00225e..010817e79a936 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -24,6 +24,15 @@ def is_aiter_found() -> bool:
 # we keep this global outside to not cause torch compile breaks.
 IS_AITER_FOUND = is_aiter_found()
 
+# Can't use dtypes.fp8 directly inside an op
+# because it returns the wrong result on gfx942.
+# This is a workaround to get the correct FP8 dtype.
+# This might be because get_gfx() is wrapped as a custom op.
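+# AITER_FP8_DTYPE is resolved once at import time and reused by the quant op
+# implementations below rather than querying dtypes.fp8 inside each op.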
+if IS_AITER_FOUND: + from aiter import dtypes + + AITER_FP8_DTYPE = dtypes.fp8 + def if_aiter_supported(func: Callable) -> Callable: """Decorator that only executes the function if @@ -45,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable: return wrapper -def _rocm_aiter_group_fp8_quant_impl( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" - from aiter import QuantType, dtypes, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) - return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8) - - -def _rocm_aiter_group_fp8_quant_fake( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - from aiter import dtypes - - M, N = x.shape - x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device) - out_bs = torch.empty( - ( - M, - (N + group_size - 1) // group_size, - ), - dtype=torch.float32, - device=x.device, - ) - return x_fp8, out_bs - - def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -522,6 +501,142 @@ def _rocm_aiter_per_token_quant_fake( ) +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=residual, + ) + return (x_quant, x_quant_scales, res) + + +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + torch.empty_like(residual, device=residual.device), + ) + + +def _rocm_aiter_rmsnorm_fp8_group_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=None, + ) + return (x_quant, x_quant_scales) + + +def _rocm_aiter_rmsnorm_fp8_group_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + ) + + +def _rocm_aiter_group_fp8_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" + from aiter import QuantType, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) + + +def _rocm_aiter_group_fp8_quant_fake( + x: 
torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + +def _rocm_aiter_act_mul_and_fp8_group_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant + + return act_mul_and_fp8_group_quant( + x, + activation="silu", + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + ) + + +def _rocm_aiter_act_mul_and_fp8_group_quant_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + assert N % 2 == 0 + N_half = N // 2 + x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N_half + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -557,7 +672,7 @@ class rocm_aiter_ops: @if_aiter_supported def is_linear_fp8_enaled(cls) -> bool: """ "Verifies device specs and availability of env variable.""" - return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + return cls.is_linear_enabled() @classmethod @if_aiter_supported @@ -632,14 +747,6 @@ class rocm_aiter_ops: ) # register all the custom ops here - direct_register_custom_op( - op_name="rocm_aiter_group_fp8_quant", - op_func=_rocm_aiter_group_fp8_quant_impl, - mutates_args=[], - fake_impl=_rocm_aiter_group_fp8_quant_fake, - dispatch_key=current_platform.dispatch_key, - ) - direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=_rocm_aiter_asm_moe_tkw1_impl, @@ -699,27 +806,46 @@ class rocm_aiter_ops: direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8_blockscale", op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, - mutates_args=[], fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=_rocm_aiter_rms_norm_impl, - mutates_args=[], fake_impl=_rocm_aiter_rms_norm_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, - mutates_args=[], fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_act_mul_and_fp8_group_quant", + op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl, + fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_group_fp8_quant", + op_func=_rocm_aiter_group_fp8_quant_impl, + fake_impl=_rocm_aiter_group_fp8_quant_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_per_tensor_quant", op_func=_rocm_aiter_per_tensor_quant_impl, diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 6848bfb6a3c53..4ebb386f75ed8 100644 
--- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,6 +5,7 @@ import functools from torch import fx as fx from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var from .post_cleanup import PostCleanupPass from .vllm_inductor_pass import VllmInductorPass +if rocm_aiter_ops.is_enabled(): + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + RocmAiterSiluMulFp8GroupQuantFusionPass, + ) + if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass @@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)] if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] if self.pass_config.fuse_attn_quant: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py new file mode 100644 index 0000000000000..8b5db9de38181 --- /dev/null +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 +from vllm.compilation.activation_quant_fusion import ActivationQuantPattern +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fusion import empty_bf16 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherSiluAndMul +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + +AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default +AITER_RMS_ADD_GROUP_QUANT_OP = ( + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default +) + +AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default +AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + +AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default +TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default + +FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + + +class AiterRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom + ops into an aiter rms_norm_group_fp8_quant op. 
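+
+    Illustrative sketch (op signatures simplified; the quant side may be either
+    the aiter or the Triton group fp8 quant op):
+
+        y = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, eps)
+        q, scale = torch.ops.vllm.rocm_aiter_group_fp8_quant(y, 128)
+
+    becomes
+
+        q, scale = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant(
+            x, weight, eps, 128
+        )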
+ """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) + + at2 = self.quant_op(at1, 128) + + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_GROUP_QUANT_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + return at[0], at[1] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class AiterFusedAddRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops + into a aiter rms_norm_with_add_group_fp8_quant op. + """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_ADD_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + ) + + at2 = self.quant_op(at1[0], 128) + + # result, scale, residual + return at2[0], at2[1], at1[1] + + def replacement( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_ADD_GROUP_QUANT_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + # result, scale, residual + return at[0], at[1], at[2] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass" + ) + + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + dynamic group fp8 quant + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( + self.patterns + ) + + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, quant_op + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> Any: + fusion_patterns = [ + AiterRMSFp8GroupQuantPattern, + AiterFusedAddRMSFp8GroupQuantPattern, + ] + return self.hash_source(self, *fusion_patterns) + + +class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): + """ + This pattern fuses aiter silu_and_mul & group fp8 quant custom + ops into an aiter silu_and_mul_group_fp8_quant op. 
+ """ + + def __init__(self, quant_op: OpOverload): + self.silu_and_mul_matcher = MatcherSiluAndMul() + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + ): + at1 = self.silu_and_mul_matcher(input) + at2 = self.quant_op(at1, 128) + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + ): + at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + return at[0], at[1] + + inputs = [ + self.silu_and_mul_matcher.inputs()[0], + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" + ) + + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self): + fusion_patterns = [ + ActivationQuantPattern, + AiterSiluMulFp8GroupQuantPattern, + ] + return VllmInductorPass.hash_source(self, *fusion_patterns) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f5200d7d34891..b459d5947863b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -196,6 +196,39 @@ direct_register_custom_op( ) +def _triton_per_token_group_quant_fp8_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + return per_token_group_quant_fp8( + x, group_size, column_major_scales=False, use_ue8m0=False + ) + + +def _triton_per_token_group_quant_fp8_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + +direct_register_custom_op( + "triton_per_token_group_quant_fp8", + _triton_per_token_group_quant_fp8_impl, + fake_impl=_triton_per_token_group_quant_fp8_fake, +) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: @@ -341,17 +374,15 @@ class W8A8BlockFp8LinearOp: if input_scale is not None: q_input = input_2d - # MI350 case uses triton kernel elif use_triton: - q_input, input_scale = per_token_group_quant_fp8( + q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8( input_2d, self.act_quant_group_shape.col, - column_major_scales=False, - use_ue8m0=False, ) - # MI300 uses tuned AITER ASM/C++ kernel else: - q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) + q_input, input_scale = 
rocm_aiter_ops.group_fp8_quant( + input_2d, self.act_quant_group_shape.col + ) return gemm_a8w8_blockscale_op( q_input,