vllm/vllm/compilation/backends.py
Russell Bryant 776dbd74f1
[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2024-10-16 22:55:59 +00:00

270 lines
11 KiB
Python

import copy
import operator
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.fx as fx
from vllm.logger import init_logger
from .compile_context import get_compile_context
from .levels import CompilationLevel
# Module-level logger shared by the compilation helpers in this file.
logger = init_logger(__name__)
def fix_functionalization(graph: fx.Graph):
"""
Rewrite the graph module to replace the pattern involving
torch._higher_order_ops.auto_functionalize.auto_functionalized
with a direct call to the inplace custom op.
# TODO: check if PyTorch nightly has fixed this issue
"""
# debug code, if we want to see the graph before the transformation
# with open("before.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
nodes_to_remove = []
for node in graph.nodes:
# Identify the auto_functionalized node
if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa
if node.args[0] == torch.ops._C.rotary_embedding.default:
# manual replace for rotary_embedding
# Now, collect the arguments
kwargs = node.kwargs
query = kwargs['query']
mm_node = query.args[0].args[0]
# Create a new call to torch.ops._C.rotary_embedding.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rotary_embedding.default,
kwargs=kwargs)
# Remove the auto_functionalized node
# Since the node may have outputs, we need to handle its users
# Replace uses of the outputs (getitem nodes) with mm_node
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
for getitem_user in list(user.users):
if (getitem_user.op == 'call_function'
and getitem_user.target
== torch.ops.aten.slice_scatter.default):
# Replace the uses of slice_scatter node
# with mm_node
getitem_user.replace_all_uses_with(mm_node)
nodes_to_remove.append(getitem_user)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
# manual replace for fused_add_rms_norm
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies
kwargs = node.kwargs
input = kwargs['input']
residual = kwargs['residual']
# Create a new call to torch.ops._C.rotary_embedding.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = input
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.rms_norm.default:
# manual replace for rms_norm
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
weight = kwargs['weight']
epsilon = kwargs['epsilon']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.rms_norm.default,
args=(out, input, weight, epsilon),
)
replace_node = out
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.silu_and_mul.default:
# manual replace for silu_and_mul
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.silu_and_mul.default,
args=(out, input),
)
replace_node = out
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
# Remove the nodes all at once
for node in nodes_to_remove:
graph.erase_node(node)
# debug code, if we want to see the graph after the transformation
# with open("after.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
def wrap_inductor(graph, example_inputs, additional_inductor_config):
    """Compile ``graph`` with Inductor, hooking in ``fix_functionalization``.

    The config patches start from a shallow copy of the global Inductor
    config. Entries from ``additional_inductor_config`` are overlaid when it
    is not None. ``post_grad_custom_post_pass`` is then forced to
    ``fix_functionalization``, with a warning if a hook was already set.

    Returns:
        The result of ``compile_fx(graph, example_inputs, config_patches=...)``.
    """
    from torch._inductor import config
    from torch._inductor.compile_fx import compile_fx

    patches = config.shallow_copy_dict()
    if additional_inductor_config is not None:
        patches.update(additional_inductor_config)

    already_set = patches['post_grad_custom_post_pass'] is not None
    if already_set:
        logger.warning(
            "post_grad_custom_post_pass is already set in the config. "
            "Overwriting it with the fix_functionalization")
    patches['post_grad_custom_post_pass'] = fix_functionalization

    return compile_fx(graph, example_inputs, config_patches=patches)
def vllm_backend(
        graph,
        example_inputs,
        additional_inductor_config: Optional[Dict] = None) -> Callable:
    """Dynamo backend that compiles ``graph`` with Inductor and optionally
    specializes it for selected sizes.

    A graph for general (symbolic) shapes is compiled right away from
    ``example_inputs``. At runtime, the values of the ``torch.SymInt``
    inputs form a shape key. If the first value of that key is one of the
    sizes from the compile context, a graph specialized for that key is
    compiled on first use and cached. Every other call runs the general
    graph.

    Args:
        graph: the FX graph handed over by Dynamo.
        example_inputs: example inputs used to compile the general graph.
        additional_inductor_config: optional extra Inductor config entries,
            forwarded to ``wrap_inductor``.

    Returns:
        The callable Dynamo invokes with the graph's runtime arguments.
    """
    context = get_compile_context()
    context = copy.deepcopy(context) if context is not None else []
    # Sizes of the first dynamic dimension that get a specialized graph.
    sizes_to_specialize: List[int] = context

    # flags for all the seen shapes, whether we need to specialize
    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
    # if we need to specialize, the compiled graph for that shape
    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

    # this is the first compilation, we will compile a graph with
    # dynamic shape, as the caller will mark first dimension as dynamic
    logger.info("Compiling a graph for general shapes")
    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
                                             additional_inductor_config)

    # TODO: Dynamo does not pass all dynamic shapes.
    # Need to investigate why. It works now because all the dynamic
    # shapes have the same value, and either of them can be used.
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
    ]

    first_run = True

    # this is the function we return to Dynamo to run finally
    def compiled_graph_wrapper(*args):
        nonlocal first_run
        runtime_shapes: Tuple[int, ...] = tuple(args[i]
                                                for i in sym_shape_indices)

        if first_run:
            # the first compilation is for profiling, we directly run it
            first_run = False
            return graph_for_symbolic_shape(*args)

        if not runtime_shapes:
            # No symbolic inputs: there is no size to specialize on, and
            # runtime_shapes[0] below would raise IndexError.
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compile_flags:
            # We haven't seen this shape before; decide once whether it
            # needs a specialized graph. Only the first dimension is
            # considered.
            # TODO: investigate if any model needs to specialize
            # beyond the first dimension
            runtime_shapes_to_compile_flags[runtime_shapes] = (
                runtime_shapes[0] in sizes_to_specialize)

        if not runtime_shapes_to_compile_flags[runtime_shapes]:
            # we don't need to specialize for this shape
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compiled_graph:
            # Specialization needed and not compiled yet: compile now, using
            # the current arguments as example inputs.
            logger.info("Compiling a graph for shapes %s", runtime_shapes)
            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
                graph, args, additional_inductor_config)
        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

    return compiled_graph_wrapper
def select_default_backend(level: int) -> Union[str, Callable]:
    """Pick the torch.compile backend for a compilation level.

    Args:
        level: a ``CompilationLevel`` value.

    Returns:
        ``"eager"`` for ``DYNAMO_AS_IS`` and ``DYNAMO_ONCE``. For
        ``INDUCTOR`` and ``INDUCTOR_MAX_AUTOTUNE``, returns ``vllm_backend``
        with the plugin-provided Inductor config bound as
        ``additional_inductor_config``. For ``INDUCTOR_MAX_AUTOTUNE``,
        ``max_autotune`` is forced to True in that config.

    Raises:
        AssertionError: for any other level (when assertions are enabled).
    """
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        return "eager"
    assert level in [
        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
    ], f"Invalid level {level}"

    from functools import partial

    from vllm.plugins import get_inductor_additional_configs

    # Work on a copy so that forcing max_autotune below does not mutate the
    # mapping returned by the plugin registry. That mapping may be shared
    # state in vllm.plugins; not verifiable from here.
    additional_configs = dict(get_inductor_additional_configs())
    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
        if "max_autotune" in additional_configs and not additional_configs[
                "max_autotune"]:
            logger.warning(
                "max_autotune is disabled, but is overridden by level %s",
                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
        additional_configs['max_autotune'] = True

    # vllm_backend is defined in this module; no self-import is needed.
    return partial(vllm_backend,
                   additional_inductor_config=additional_configs)