mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 03:16:31 +08:00
237 lines
9.6 KiB
Python
237 lines
9.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import operator
|
|
from collections.abc import Iterable
|
|
|
|
import torch
|
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
from .fx_utils import is_func
|
|
from .vllm_inductor_pass import VllmInductorPass
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """
        Rewrite supported ``auto_functionalized`` nodes in ``graph`` into
        direct calls to the wrapped (in-place) custom ops.

        Nodes made obsolete by the rewrite (the ``auto_functionalized`` node
        itself, its ``getitem`` users and, for the rotary-embedding split
        pattern, the ``slice_scatter`` users) are staged in
        ``self.nodes_to_remove`` and erased together after the walk over
        ``graph.nodes`` has finished.

        :param graph: The FX graph to rewrite in place.
        """
        # XPU does not support auto-functionalization yet.
        # Will enable this when switch to vllm-xpu-kernels.
        if current_platform.is_xpu():
            logger.debug(
                "XPU platform does not support fix functionalization pass currently."
            )
            return

        self.nodes_to_remove: list[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            # args[0] of an auto_functionalized node is the op overload it
            # wraps (insert_defunctionalized calls it directly).
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs["query"]
                key = kwargs["key"]
                getitem_nodes = self.getitem_users(node)

                if (
                    is_func(query, operator.getitem)
                    # Only inspect `key` when it is actually a graph node.
                    # NOTE(review): assumes `key` may be None when the op
                    # schema marks it optional — confirm against the op
                    # registration. Non-node keys take the generic path below.
                    and isinstance(key, torch.fx.Node)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of an mm_node.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into mm_node. So after de-functionalization, we can
                    # just use mm_node directly.
                    mm_node = query.args[0].args[0]
                    for user in getitem_nodes.values():
                        for user_of_getitem in user.users:
                            if is_func(
                                user_of_getitem, torch.ops.aten.slice_scatter.default
                            ):
                                user_of_getitem.replace_all_uses_with(mm_node)
                                self._remove(user_of_getitem)
                        self._remove(user)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)

                else:
                    # Directly replace the auto_functionalize(rotary_embedding)
                    # with the inplace rotary_embedding. In theory, we shouldn't
                    # do this blindly, but in practice in vLLM it's ok. The best
                    # solution is to use auto_functionalization_v2 and then use
                    # inductor's builtin defunctionalization (reinplacing) pass.
                    mutated_args = {1: "query", 2: "key"}
                    self.defunctionalize(graph, node, mutated_args)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: "input", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "scale", 3: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: "result"}
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input")
                )
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input", "scale")
                )
            elif (
                # This op is only present in some builds, so check before use.
                hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
                and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
            ):
                mutated_args = {1: "result", 2: "result_block_scale"}
                self.defunctionalize(
                    graph,
                    node,
                    mutated_args,
                    args=(
                        "result",
                        "result_block_scale",
                        "input",
                        "input_global_scale",
                    ),
                )
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_cleanup")

        # Remove the nodes all at once. Nodes are staged consumers-first
        # (getitem/slice_scatter users before the auto_functionalized node),
        # so each node has no remaining users by the time it is erased.
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.nodes_to_remove.clear()

    def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]):
        """
        Stage a node (or nodes) for removal at the end of the pass.

        :param node_or_nodes: A single node, or an iterable of nodes, to be
            erased from the graph once __call__ has finished rewriting.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, torch.fx.Node | str],
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.

        :param graph: Graph containing ``node``.
        :param node: The auto-functionalized node to defunctionalize.
        :param mutated_args: Maps getitem index -> replacement (a node, or a
            kwarg name of ``node``).
        :param args: Optional positional args for the new call; see
            insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
    ):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.

        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
            If the value of an arg is a string, `node.kwargs[arg]` is used.
        :raises KeyError: if a getitem user reads an index that has no entry
            in ``mutated_args``.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.

        If several getitem users read the same index, the last one visited
        in ``node.users`` is kept.
        """
        return {
            user.args[1]: user
            for user in node.users
            if is_func(user, operator.getitem)
        }

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
            If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), (
            f"node must be auto-functionalized, is {node} instead"
        )

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                graph.call_function(function, args=args)