mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-06 23:36:32 +08:00
Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Signed-off-by: Charlie Fu <Charlie.Fu@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: wuhuikx <hattie.wu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
243 lines
7.8 KiB
Python
243 lines
7.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch._inductor.pattern_matcher as pm
|
|
from torch import fx
|
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
|
from torch._ops import OpOverload
|
|
|
|
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
|
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
|
from vllm.config import VllmConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
from .fusion import empty_bf16
|
|
from .inductor_pass import enable_fake_mode
|
|
from .matcher_utils import MatcherSiluAndMul
|
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
|
|
|
logger = init_logger(__name__)

# fp8 dtype reported by the current platform; handed to the pattern classes.
FP8_DTYPE = current_platform.fp8_dtype()

# --- Unfused ops that the patterns below search for -------------------------
# Normalization ops (without / with residual add).
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

# Stand-alone group fp8 quant ops; each pattern is registered once per op.
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

# --- Fused ops that the replacements emit -----------------------------------
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
    torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
|
|
|
|
|
class AiterRMSFp8GroupQuantPattern:
    """
    This pattern fuses aiter rms_norm & group fp8 quant custom
    ops into an aiter rms_norm_group_fp8_quant op.

    Args:
        epsilon: variance epsilon used both in the matched rms_norm call and
            in the fused replacement.
        quant_dtype: quantized dtype; it is stored on the instance but not
            read anywhere inside this class.
        quant_op: the stand-alone group quant op expected after rms_norm.
        group_size: group size passed to ``quant_op`` in the pattern and to
            the fused op in the replacement. Defaults to 128, the value that
            used to be hard-coded in both places.
            NOTE(review): whether the underlying kernels accept other group
            sizes cannot be told from this file - confirm before overriding.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        quant_op: OpOverload,
        group_size: int = 128,
    ):
        self.epsilon = epsilon
        self.quant_dtype = quant_dtype
        self.quant_op = quant_op
        # Single source of truth so pattern and replacement cannot drift apart.
        self.group_size = group_size

    def register(self, pm_pass: PatternMatcherPass):
        """Register the rms_norm + group quant -> fused op rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            normed = AITER_RMS_OP(
                x=input, weight=weight, variance_epsilon=self.epsilon
            )

            quantized = self.quant_op(normed, self.group_size)

            # Only the first two outputs of the quant op are part of the match.
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            fused = AITER_RMS_GROUP_QUANT_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )

            return fused[0], fused[1]

        # Example inputs handed to register_replacement.
        inputs = [
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
|
|
|
|
|
class AiterFusedAddRMSFp8GroupQuantPattern:
    """
    This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
    into an aiter rms_norm_with_add_group_fp8_quant op.

    Args:
        epsilon: variance epsilon used both in the matched rms_norm_with_add
            call and in the fused replacement.
        quant_dtype: quantized dtype; it is stored on the instance but not
            read anywhere inside this class.
        quant_op: the stand-alone group quant op expected after the norm.
        group_size: group size passed to ``quant_op`` in the pattern and to
            the fused op in the replacement. Defaults to 128, the value that
            used to be hard-coded in both places.
            NOTE(review): whether the underlying kernels accept other group
            sizes cannot be told from this file - confirm before overriding.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        quant_op: OpOverload,
        group_size: int = 128,
    ):
        self.epsilon = epsilon
        self.quant_dtype = quant_dtype
        self.quant_op = quant_op
        # Single source of truth so pattern and replacement cannot drift apart.
        self.group_size = group_size

    def register(self, pm_pass: PatternMatcherPass):
        """Register the add+rms_norm + group quant -> fused op rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
        ):
            norm_out = AITER_RMS_ADD_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
            )

            # Only the first output of the norm op is quantized.
            quant_out = self.quant_op(norm_out[0], self.group_size)

            # result, scale, residual
            return quant_out[0], quant_out[1], norm_out[1]

        def replacement(
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
        ):
            fused = AITER_RMS_ADD_GROUP_QUANT_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )

            # result, scale, residual
            return fused[0], fused[1], fused[2]

        # Example inputs handed to register_replacement.
        inputs = [
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
|
|
|
|
|
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    Patterns are registered for every combination of the epsilons
    ``1e-5`` / ``1e-6`` and the two group fp8 quant ops
    (``AITER_GROUP_FP8_QUANT_OP``, ``TRITON_GROUP_FP8_QUANT_OP``).
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build the pattern matcher pass and register all fusion patterns."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
                # Fuse fused_add_rms_norm + dynamic group fp8 quant
                AiterFusedAddRMSFp8GroupQuantPattern(
                    epsilon, FP8_DTYPE, quant_op
                ).register(self.patterns)

                # Fuse rms_norm + dynamic group fp8 quant
                AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return a hash over the source of this pass and its pattern classes."""
        fusion_patterns = [
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *fusion_patterns)
|
|
|
|
|
|
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    This pattern fuses aiter silu_and_mul & group fp8 quant custom
    ops into an aiter silu_and_mul_group_fp8_quant op.

    Note that the base-class ``__init__`` is not invoked; only the
    attributes set below exist on the instance.

    Args:
        quant_op: the stand-alone group quant op expected after silu_and_mul.
        group_size: group size passed to ``quant_op`` in the pattern and to
            the fused op in the replacement. Defaults to 128, the value that
            used to be hard-coded in both places.
            NOTE(review): whether the underlying kernels accept other group
            sizes cannot be told from this file - confirm before overriding.
    """

    def __init__(self, quant_op: OpOverload, group_size: int = 128):
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op
        # Single source of truth so pattern and replacement cannot drift apart.
        self.group_size = group_size

    def register(self, pm_pass: PatternMatcherPass):
        """Register the silu_and_mul + group quant -> fused op rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
        ):
            activated = self.silu_and_mul_matcher(input)
            quantized = self.quant_op(activated, self.group_size)
            # Only the first two outputs of the quant op are part of the match.
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
        ):
            fused = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=self.group_size)
            return fused[0], fused[1]

        # Reuse the matcher's first example input for register_replacement.
        inputs = [
            self.silu_and_mul_matcher.inputs()[0],
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
|
|
|
|
|
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Fusion pass that replaces a pre-defined set of custom ops with fused ops,
    using the torch pattern matcher to locate and rewrite the patterns.

    One ``AiterSiluMulFp8GroupQuantPattern`` is registered per supported
    group fp8 quant op.

    NOTE(review): this pass has been described as a singleton because
    patterns can only be registered once, something expected to be addressed
    in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    Nothing in this class enforces single instantiation - confirm with callers.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Create the pattern matcher pass and register one pattern per quant op."""
        super().__init__(config)

        pattern_pass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )
        self.patterns: PatternMatcherPass = pattern_pass

        supported_quant_ops = (AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP)
        for quant_op in supported_quant_ops:
            silu_mul_pattern = AiterSiluMulFp8GroupQuantPattern(quant_op)
            silu_mul_pattern.register(pattern_pass)

        self.dump_patterns(config, pattern_pass)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Run the registered patterns over ``graph`` and store the match count."""
        num_replaced = self.patterns.apply(graph)
        self.matched_count = num_replaced
        logger.debug("Replaced %s patterns", num_replaced)

    def uuid(self):
        """Return a hash over the source of this pass and its pattern classes."""
        hashed_sources = (
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )
        return VllmInductorPass.hash_source(self, *hashed_sources)
|