# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, NamedTuple

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
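# NOTE: torch has no native FP4 dtype; NVFP4-quantized values are sub-byte
# and stored packed in uint8 (two 4-bit values per byte), hence the uint8
# container dtype above.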
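

# Shape-only CUDA tensor factories. Presumably kept as helpers for building
# example inputs when registering patterns; nothing in this file calls them.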
def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


def empty_i32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")


RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
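# The NVFP4 quant op is only present in CUDA builds compiled with FP4 kernel
# support, hence the conditional registration.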
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple identifying the type of RMSNorm + quant fusion.

    quant: type of quantization
    fused_add: whether the op also performs the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self):
        return (
            f"FusedRMSQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )
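

# Example: FusedRMSQuantKey(kFp8StaticTensorSym, fused_add=True) identifies
# the fused_add_rms_norm + static fp8 quant fusion registered below.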
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}
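# Both dynamic-token keys map to the same kernel: it takes an optional
# residual argument (passed as residual=None in the non-fused-add pattern).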


class RMSNormQuantPattern:
    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(key.quant)
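

# Each pattern class below pairs a traced `pattern` graph (built from the
# matcher helpers) with a `replacement` graph that emits the fused custom op
# via auto_functionalized, then registers the pair with the pattern matcher.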
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
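

# Same as above, but matching the fused_add_rms_norm (residual add) variant;
# the rewrite also returns the updated residual.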
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
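

# Dynamic-quant variant: the scale is computed by the fused op itself, so the
# pattern has no scale input and the replacement allocates one via make_scale.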
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
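

# Dynamic-quant fused-add variant. Note the fused op returns outputs in the
# order (result, scale, residual), so the replacement reorders them to match
# the pattern's (result, residual, scale).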
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm and quant custom ops into a fused rms_norm_quant
    op. It also supports the fused_add_rms_norm variant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Register fused-add patterns before the plain rms_norm patterns,
        # as the latter match a subset of the former's torch ops.
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        return self.hash_source(
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
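

# Usage sketch (hedged; the actual wiring lives in vLLM's compilation pass
# manager): construct the pass once with the VllmConfig, then call it on the
# post-grad fx.Graph:
#
#     fusion_pass = RMSNormQuantFusionPass(get_current_vllm_config())
#     fusion_pass(graph)  # rewrites matching rms_norm + quant subgraphs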