vllm/vllm/compilation/fusion.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
This commit adds SPDX license headers to python source files as
recommended to
the project by the Linux Foundation. These headers provide a concise way
that is
both human and machine readable for communicating license information
for each
source file. It helps avoid any ambiguity about the license of the code
and can
    also be easily used by tools to help manage license compliance.
    
The Linux Foundation runs license scans against the codebase to help
ensure
    we are in compliance with the licenses of the code we use, including
dependencies. Having these headers in place helps that tool do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

618 lines
24 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch._inductor.pattern_matcher as pm
# TODO(luka) use vllm.utils once #10836 landed
from compressed_tensors.quantization import FP8_DTYPE
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from .fx_utils import find_getitem_maybe
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
per_tensor: per-tensor quantization if True, per-token if False
symmetric: symmetric if True, asymmetric if False
"""
dtype: torch.dtype
static: bool
per_tensor: bool = True
symmetric: bool = True
def __str__(self):
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'per_tensor' if self.per_tensor else 'per_token'},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
QUANT_OPS: Dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa
}
class FusedRMSQuantKey(NamedTuple):
"""
Named tuple for identifying the type of RMSNorm + quant fusion.
quant: type of quantization
fused_add: does the op also perform the residual add
"""
quant: QuantKey
fused_add: bool
def __str__(self):
return (f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)")
FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
}
class QuantMultiOutputMatch(MultiOutputMatch):
def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
Note that the 0th element is None for auto-functionalized in-place ops.
Hence, others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)
# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)
# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]
# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]
# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]
# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
assert key.quant in QUANT_OPS, \
f"unsupported quantization scheme {key.quant}"
self.QUANT_OP = QUANT_OPS[key.quant]
assert key in FUSED_OPS, \
f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale)
# result
return at2[1]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon)
# result
return at[1]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
pm_pass)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale)
# result, residual
return at1[1], at[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon)
# result, residual
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
**kwargs)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool,
symmetric=True):
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale,
scale_ub=None)
# result, scale
return at2[1], at2[2]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None)
# result, scale
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 1
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
del kwargs["result_rms"] # not used in the fused op
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool = True,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale,
scale_ub=None)
# result, residual, scale
return at1[1], at[2], at1[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual)
# result, residual, scale
return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, scale, and residual.
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)
class FusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It also manually processes multi-output matches, as those are broken in
the torch pattern matcher.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig.PassConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.config = config
return cls._instance
def __init__(self, config: CompilationConfig.PassConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: List[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
per_tensor=False).register(
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE,
per_tensor=False).register(
self.patterns,
self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process()
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_pattern_match")
# Manually process multi-output matches (and run DCE)
self.process_matches(graph)
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()