import copy
import operator
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.fx as fx

from vllm.logger import init_logger

from .compile_context import get_compile_context
from .levels import CompilationLevel

logger = init_logger(__name__)


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph module to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    # TODO: check if PyTorch nightly has fixed this issue
    """
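    # Illustrative sketch (not executed): functionalization rewrites an
    # in-place custom op into a pure call whose outputs are read back via
    # operator.getitem, roughly
    #     tup = auto_functionalized(torch.ops._C.rms_norm.default, out=..., ...)
    #     result = tup[1]
    # This pass re-inserts the in-place call and redirects every getitem
    # user to the tensor the op mutates, so the functional wrapper and the
    # copies it implies can be erased.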

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    nodes_to_remove = []
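    # collect nodes first and erase them after the walk: a node can only be
    # erased once it has no users, so the getitem users recorded below must
    # be erased before the auto_functionalized node itself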

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs['query']
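                # in the pattern this pass appears to expect, `query` is a
                # view/slice of the QKV projection's matmul output, so two
                # hops up the producer chain reach that matmul node; its
                # output stands in for the slice_scatter result below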
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(torch.ops._C.rotary_embedding.default,
                                        kwargs=kwargs)

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (getitem_user.op == 'call_function'
                                    and getitem_user.target
                                    == torch.ops.aten.slice_scatter.default):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs['input']
                residual = kwargs['residual']

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
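                        # auto_functionalized returns the wrapped op's result
                        # at index 0 followed by the mutated arguments, so
                        # index 1 is the updated `input` and index 2 the
                        # updated `residual`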
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']
                weight = kwargs['weight']
                epsilon = kwargs['epsilon']
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def wrap_inductor(graph, example_inputs, additional_inductor_config):
    from torch._inductor import config
    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)
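    # inductor applies `post_grad_custom_post_pass` to the post-grad FX
    # graph, after functionalization, which is exactly where the
    # auto_functionalized wrappers removed by fix_functionalization appear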
    if current_config['post_grad_custom_post_pass'] is not None:
        logger.warning(
            "post_grad_custom_post_pass is already set in the config. "
            "Overwriting it with the fix_functionalization pass.")
    current_config['post_grad_custom_post_pass'] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)


def vllm_backend(
        graph,
        example_inputs,
        additional_inductor_config: Optional[Dict] = None) -> Callable:

    context = get_compile_context()
    context = copy.deepcopy(context) if context is not None else []
    sizes_to_specialize: List[int] = context
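
    # the context is expected to carry the list of sizes (e.g. batch sizes)
    # to specialize for, installed by the caller via a set_compile_context
    # helper in .compile_context (assumed companion to get_compile_context)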

    # flags for all the seen shapes, whether we need to specialize
    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}

    # if we need to specialize, the compiled graph for that shape
    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

    # this is the first compilation; we compile a graph with dynamic shape,
    # as the caller will mark the first dimension as dynamic
    logger.info("Compiling a graph for general shapes")
    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
                                             additional_inductor_config)
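
    # Dynamo passes the symbolic sizes in as SymInt arguments alongside the
    # tensors; record their positions so the wrapper below can read the
    # concrete values out of the runtime args at the same indices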
    # TODO: Dynamo does not pass all dynamic shapes.
    # Need to investigate why. It works now because all the dynamic
    # shapes have the same value, and either of them can be used.
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
    ]

    first_run = True

    # this is the function we finally return to Dynamo to run
    def compiled_graph_wrapper(*args):
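
        # dispatch order: the very first call (profiling) and any shape not
        # marked for specialization run the general dynamic-shape graph;
        # specialized shapes get their own graph, compiled lazily on first
        # sight and cached afterwards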

        runtime_shapes: Tuple[int,
                              ...] = tuple(args[i] for i in sym_shape_indices)

        nonlocal first_run
        nonlocal runtime_shapes_to_compile_flags
        nonlocal runtime_shapes_to_compiled_graph

        if first_run:
            # the first compilation is for profiling; run it directly
            first_run = False
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compile_flags:
            # we haven't seen this shape before, so query whether we need to
            # specialize for it; we only specialize on the first dimension.
            # TODO: investigate if any model needs to specialize
            # beyond the first dimension
            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
                0] in sizes_to_specialize

        if not runtime_shapes_to_compile_flags[runtime_shapes]:
            # we don't need to specialize for this shape
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compiled_graph:
            # we need to specialize for this shape and haven't compiled for
            # it yet, so compile the graph for this shape now
            logger.info("Compiling a graph for shapes %s", runtime_shapes)
            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
                graph, args, additional_inductor_config)

        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

    return compiled_graph_wrapper


def select_default_backend(level: int) -> Union[str, Callable]:
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        backend_str = "eager"
        return backend_str
    assert level in [
        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
    ], f"Invalid level {level}"

    from vllm.compilation.backends import vllm_backend
    from vllm.plugins import get_inductor_additional_configs
    additional_configs = get_inductor_additional_configs()

    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
        if "max_autotune" in additional_configs and not additional_configs[
                "max_autotune"]:
            logger.warning(
                "max_autotune is disabled, but is overridden by level %s",
                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
        additional_configs['max_autotune'] = True

    from functools import partial
    backend = partial(vllm_backend,
                      additional_inductor_config=additional_configs)
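
    # either return value (the "eager" string above or this partial) is a
    # valid `backend=` argument for torch.compile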
    return backend