diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 16a4271655efa..0ddb82b7c3fc2 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import cast
+import itertools
 
 import pytest
 import torch
@@ -16,7 +16,13 @@ from vllm.compilation.activation_quant_fusion import (
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -25,7 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp,
-    cutlass_fp8_supported,
+    maybe_create_device_identity,
 )
 from vllm.platforms import current_platform
 
@@ -54,6 +60,8 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
 
     def forward(self, x):
         y = self.silu_and_mul(x)
@@ -61,7 +69,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
         return x2
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            (
+                QUANT_OPS[kFp8StaticTensorSym]
+                if self.enable_quant_fp8_custom_op
+                else torch.ops.aten.reciprocal
+            ),
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
         assert silu_and_mul_nvfp4_quant_supported
 
         self.silu_and_mul = SiluAndMul()
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
 
         # create nvfp4 weight
         w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
         return out
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            QUANT_OPS[kNvfp4Quant],
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
 @pytest.mark.parametrize(
-    "model_class",
-    cast(
-        list[type],
-        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-        if is_nvfp4_supported()
-        else [TestSiluMulFp8QuantModel],
-    ),
+    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
+    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+    + [(TestSiluMulNvfp4QuantModel, False, False)],
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize(
-    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
-)
 @pytest.mark.skipif(
     envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
 )
 def test_fusion_silu_and_mul_quant(
-    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    enable_silu_mul_custom_op: bool,
+    enable_quant_fp8_custom_op: bool,
+    cuda_force_torch: bool,
 ):
-    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
-        pytest.skip("Duplicate tests for NVFP4")
+    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+        pytest.skip("NVFP4 is not supported on this GPU.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
+    maybe_create_device_identity()
 
     x = torch.rand(num_tokens, hidden_size * 2)
 
     # Reshape pass is needed for the fusion pass to work
-    config = VllmConfig()
-    config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True)
-    )
-    fusion_pass = ActivationQuantFusionPass(config)
-
-    passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
-    backend = TestBackend(*passes)
-    model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
-
-    # First dimension dynamic
-    torch._dynamo.mark_dynamic(x, 0)
-
-    result = model(x)
-
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
-
-    # Check that it gives the same answer
-    if model_class == TestSiluMulFp8QuantModel:
-        atol, rtol = 1e-3, 1e-3
-    elif model_class == TestSiluMulNvfp4QuantModel:
-        atol, rtol = 1e-1, 1e-1
-
-    torch.testing.assert_close(
-        result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+    custom_ops = []
+    if enable_silu_mul_custom_op:
+        custom_ops.append("+silu_and_mul")
+    if enable_quant_fp8_custom_op:
+        custom_ops.append("+quant_fp8")
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
     )
 
-    assert fusion_pass.matched_count == 1
+    with set_current_vllm_config(config):
+        fusion_pass = ActivationQuantFusionPass(config)
 
-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+        backend = TestBackend(*passes)
+        model = model_class(
+            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
+        )
 
-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+        # First dimension dynamic
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        # Check that it gives the same answer
+        if model_class == TestSiluMulFp8QuantModel:
+            atol, rtol = 1e-3, 1e-3
+        elif model_class == TestSiluMulNvfp4QuantModel:
+            atol, rtol = 1e-1, 1e-1
+
+        torch.testing.assert_close(
+            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+        )
+
+        assert fusion_pass.matched_count == 1
+
+        # In pre-nodes, quant op should be present and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())
+
+        # In post-nodes, fused kernels should be present and quant op should not
+        backend.check_after_ops(model.ops_in_model_after())
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 7448bb122152d..b5fd67c5b027f 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -18,12 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
-    kStaticTensorScale,
 )
 from vllm.platforms import current_platform
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -66,6 +66,8 @@ class ActivationQuantPattern(ABC):
         )
         self.FUSED_OP = FUSED_OPS[self.quant_key]
 
+        self.silu_and_mul_matcher = MatcherSiluAndMul()
+
     def empty_quant(self, *args, **kwargs):
         kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
         return torch.empty(*args, **kwargs)
@@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
     """
     Fusion for SiluMul+Fp8StaticQuant Pattern
     """
 
-    def __init__(self, symmetric: bool = True):
-        quant_key = QuantKey(
-            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
-        )
-        super().__init__(quant_key)
+    def __init__(self):
+        super().__init__(kFp8StaticTensorSym)
+        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale
-            )
-            return at2[1]
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            result_quant = self.quant_matcher(result_silu_mul, scale)
+            return result_quant[0]
 
         def replacement(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
+            d = input.shape[-1] // 2
+            output_shape = input.shape[:-1] + (d,)
+            result = torch.empty(
+                output_shape, device=input.device, dtype=self.quant_dtype
+            )
             at = auto_functionalized(
                 self.FUSED_OP, result=result, input=input, scale=scale
             )
             return at[1]
 
         inputs = [
-            self.empty_quant(5, 4),  # result
-            empty_bf16(5, 4),  # result_silu_mul
-            empty_bf16(5, 4),  # input
-            empty_fp32(1, 1),  # scale
+            *self.silu_and_mul_matcher.inputs(),  # input
+            self.quant_matcher.inputs()[1],  # scale
         ]
+        pattern(*inputs)
 
         register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
@@ -132,24 +130,22 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
         def pattern(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            at = auto_functionalized(
                 self.QUANT_OP,
                 output=result,
-                input=at1[1],
+                input=result_silu_mul,
                 output_scale=output_scale,
                 input_scale=scale,
             )
-            return at2[1], at2[2]
+            return at[1], at[2]
 
         def replacement(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
@@ -165,7 +161,6 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
         inputs = [
             self.empty_quant(5, 32),  # result
             empty_i32(128, 4),  # output_scale
-            empty_bf16(5, 64),  # result_silu_mul
             empty_bf16(5, 64),  # input
             empty_fp32(1, 1),  # scale
         ]
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index c4eb463de1d2e..383fe6033a6df 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -7,6 +7,7 @@ from torch._higher_order_ops import auto_functionalized
 from torch._ops import OpOverload
 
 from vllm.config import get_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -31,6 +32,8 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
+SILU_MUL_OP = torch.ops._C.silu_and_mul.default
+
 
 class MatcherCustomOp(ABC):
     def __init__(self, enabled: bool):
@@ -206,3 +209,30 @@ class MatcherQuantFP8(MatcherCustomOp):
             return [input, self.empty_f32(1, 1)]
 
         return [input]
+
+
+class MatcherSiluAndMul(MatcherCustomOp):
+    def __init__(self, enabled: bool | None = None):
+        if enabled is None:
+            enabled = SiluAndMul.enabled()
+        super().__init__(enabled)
+
+    def inputs(self) -> list[torch.Tensor]:
+        input = self.empty(5, 4)
+        return [input]
+
+    def forward_custom(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
+        return result[1]
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return SiluAndMul.forward_native(x)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index fb1122708953c..3471ee327cf8c 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -80,7 +80,8 @@ class SiluAndMul(CustomOp):
         elif current_platform.is_cpu():
             self._forward_method = self.forward_native
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def forward_native(x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]