[torch.compile] Enable silu_mul_fp8_quant fusion without custom ops enabled (#27146)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Jiangyun Zhu 2025-10-22 12:22:39 +08:00 committed by GitHub
parent ceacedc1f9
commit ab3e80042e
GPG Key ID: B5690EEEBB952194
4 changed files with 134 additions and 76 deletions
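
For reviewers, a minimal sketch of how the fused path is now exercised with both custom ops disabled, mirroring the test configuration added in this commit (the imports and config fields are taken from the diff below; the empty custom_ops list is the illustrative part):

from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)

# With custom_ops left empty, SiluAndMul and QuantFP8 lower to native torch
# ops; the matcher-based patterns introduced here still rewrite that graph
# into the fused silu_mul_fp8_quant kernel when the fusion pass runs.
config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=[],  # neither "+silu_and_mul" nor "+quant_fp8"
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ),
)
with set_current_vllm_config(config):
    fusion_pass = ActivationQuantFusionPass(config)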

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import itertools
import pytest
import torch
@@ -16,7 +16,13 @@ from vllm.compilation.activation_quant_fusion import (
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CompilationMode,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -25,7 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
@@ -54,6 +60,8 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
@@ -61,7 +69,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
return x2
def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
(
QUANT_OPS[kFp8StaticTensorSym]
if self.enable_quant_fp8_custom_op
else torch.ops.aten.reciprocal
),
]
def ops_in_model_after(self):
return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul()
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
# create nvfp4 weight
w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return out
def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
QUANT_OPS[kNvfp4Quant],
]
def ops_in_model_after(self):
return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class",
cast(
list[type],
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported()
else [TestSiluMulFp8QuantModel],
),
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
)
# cuda_force_torch is used to test the torch code path on platforms where
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2)
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
)
fusion_pass = ActivationQuantFusionPass(config)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
custom_ops = []
if enable_silu_mul_custom_op:
custom_ops.append("+silu_and_mul")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
)
assert fusion_pass.matched_count == 1
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -18,12 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@@ -66,6 +66,8 @@ class ActivationQuantPattern(ABC):
)
self.FUSED_OP = FUSED_OPS[self.quant_key]
self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args, **kwargs):
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self, symmetric: bool = True):
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(quant_key)
def __init__(self):
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
return at2[1]
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
def replacement(
result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP, result=result, input=input, scale=scale
)
return at[1]
inputs = [
self.empty_quant(5, 4), # result
empty_bf16(5, 4), # result_silu_mul
empty_bf16(5, 4), # input
empty_fp32(1, 1), # scale
*self.silu_and_mul_matcher.inputs(), # input
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
@@ -132,24 +130,22 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
at2 = auto_functionalized(
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
output=result,
input=at1[1],
input=result_silu_mul,
output_scale=output_scale,
input_scale=scale,
)
return at2[1], at2[2]
return at[1], at[2]
def replacement(
result: torch.Tensor,
output_scale: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
@@ -165,7 +161,6 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
inputs = [
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # result_silu_mul
empty_bf16(5, 64), # input
empty_fp32(1, 1), # scale
]

View File

@@ -7,6 +7,7 @@ from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -31,6 +32,8 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
@@ -206,3 +209,30 @@ class MatcherQuantFP8(MatcherCustomOp):
return [input, self.empty_f32(1, 1)]
return [input]
class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None):
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 4)
return [input]
def forward_custom(
self,
x: torch.Tensor,
) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
return result[1]
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
return SiluAndMul.forward_native(x)

View File

@@ -80,7 +80,8 @@ class SiluAndMul(CustomOp):
elif current_platform.is_cpu():
self._forward_method = self.forward_native
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@staticmethod
def forward_native(x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
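
A note on the last hunk: forward_native becomes a staticmethod so the new MatcherSiluAndMul can call it directly on the class, without constructing a SiluAndMul module. A small illustrative check of that usage (assuming an environment where vllm imports cleanly):

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import SiluAndMul

# No instance needed: this is the property MatcherSiluAndMul.forward_native
# relies on when tracing the pattern for the non-custom-op path.
x = torch.randn(4, 8)
d = x.shape[-1] // 2
expected = F.silu(x[..., :d]) * x[..., d:]
assert torch.allclose(SiluAndMul.forward_native(x), expected)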