[torch.compile] fix functionalization (#8480)
parent: 8a0cf1ddc3
commit: a36e070dad
@@ -16,7 +16,12 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
-              enforce_eager=True,
-              load_format="dummy")
-    llm.generate(prompts, sampling_params)
+    llm = LLM(model=model, enforce_eager=True)
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
vllm/compilation/backends.py (new file, 156 lines)
@@ -0,0 +1,156 @@

import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph module to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    # TODO: check if PyTorch nightly has fixed this issue
    """

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs['query']
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(torch.ops._C.rotary_embedding.default,
                                        kwargs=kwargs)

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (getitem_user.op == 'call_function'
                                    and getitem_user.target
                                    == torch.ops.aten.slice_scatter.default):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs['input']
                residual = kwargs['residual']

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']
                weight = kwargs['weight']
                epsilon = kwargs['epsilon']
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def vllm_backend(graph, example_inputs):
    from torch._inductor import config
    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx
    current_config['post_grad_custom_post_pass'] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
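
Aside (not part of the commit): the core graph-surgery pattern that fix_functionalization relies on, shown on a toy FX graph. This is a minimal sketch using only public torch.fx APIs; the Toy module and the add-to-mul rewrite are made up for illustration, but the steps mirror the pass above: insert a replacement call with graph.inserting_before, redirect consumers with replace_all_uses_with, collect the obsolete nodes, and erase them after the traversal.

import operator

import torch
import torch.fx as fx


class Toy(torch.nn.Module):

    def forward(self, x, y):
        return x + y


gm = fx.symbolic_trace(Toy())

nodes_to_remove = []
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target == operator.add:
        # Insert the replacement right before the node being rewritten.
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(operator.mul, args=node.args)
        # Redirect every consumer of the old node to the new one.
        node.replace_all_uses_with(new_node)
        nodes_to_remove.append(node)

# Erase nodes only after the traversal, as the pass above does.
for node in nodes_to_remove:
    gm.graph.erase_node(node)
gm.recompile()

x, y = torch.ones(2), torch.full((2,), 3.0)
assert torch.equal(gm(x, y), x * y)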

@@ -1064,8 +1064,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     "This may lead to less accurate results!")
 
         if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
+            from vllm.compilation.backends import vllm_backend
             from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = get_torch_compile_backend() or vllm_backend
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
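
Aside (not part of the commit): a sketch of how the new default is exercised end to end. This assumes a CUDA build of vLLM at this commit and access to the model weights; with VLLM_TEST_DYNAMO_GRAPH_CAPTURE set and no plugin-registered backend, the model runner now compiles the model with vllm_backend instead of the previous "eager" default.

import os

# Must be set before the model is loaded.
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams

# Mirrors the updated test above; generation now runs through the
# Dynamo-captured graph compiled by vllm_backend.
llm = LLM(model="meta-llama/Meta-Llama-3-8B", enforce_eager=True)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)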