# vllm/vllm/compilation/activation_quant_fusion.py

# SPDX-License-Identifier: Apache-2.0

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)
def silu_mul_pattern_static(result: torch.Tensor,
                            result_silu_mul: torch.Tensor,
                            input: torch.Tensor, scale: torch.Tensor):
    # Pattern to match: silu_and_mul followed by a static-scale FP8 quant,
    # written with auto_functionalized so it lines up with the traced graph.
    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
                              result=result_silu_mul,
                              input=input)
    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at1[1],
                              scale=scale)
    # auto_functionalized returns a tuple; index 1 is the mutated result.
    return at2[1]
def silu_mul_replacement_static(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
                                input: torch.Tensor, scale: torch.Tensor):
    # Replacement: a single fused kernel that does SiluAndMul and the FP8
    # quantization in one pass over the data.
    at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
                             result=result,
                             input=input,
                             scale=scale)
    return at[1]
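

# For reference, a minimal eager-mode sketch of the semantics the fused op
# is expected to preserve (an illustration under stated assumptions, not the
# kernel implementation; the helper name is hypothetical): silu_and_mul
# computes SiLU(x[..., :d]) * x[..., d:] for an input of width 2*d, and
# static_scaled_fp8_quant divides by a fixed scale and casts to FP8.
def _silu_mul_quant_reference(x: torch.Tensor,
                              scale: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    fp8 = current_platform.fp8_dtype()
    finfo = torch.finfo(fp8)
    # Clamp to the representable FP8 range before casting.
    return (y / scale).clamp(finfo.min, finfo.max).to(fp8)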
def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp8(*args, **kwargs):
    fp8 = current_platform.fp8_dtype()
    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")


def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
class ActivationQuantFusionPass(VllmInductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        # Example tensors used to trace the pattern and its replacement;
        # the shapes are arbitrary, but dtypes must match the real ops.
        inputs = [
            empty_fp8(5, 4),   # Quant output
            empty_bf16(5, 4),  # Silu_and_mul output
            empty_bf16(5, 4),  # Input
            empty_fp32(1, 1),  # Scale
        ]
        register_replacement(silu_mul_pattern_static,
                             silu_mul_replacement_static, inputs, fwd_only,
                             self.patterns)
    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_act_quant_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                     count)

        self.dump_graph(graph, "after_act_quant_fusion")
        self.end_and_log()
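

# A minimal usage sketch, assuming a populated VllmConfig and a traced FX
# graph are already in hand (the variable names below are illustrative, not
# part of the pass's API):
#
#     fusion_pass = ActivationQuantFusionPass(vllm_config)
#     fusion_pass(fx_graph)  # rewrites matching silu_and_mul + quant pairs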