mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 15:14:53 +08:00
[Bugfix] Defunctionalize TRTLLM AR+Norm op for avoiding extra clone kernel before it (#29631)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
b08025a83b
commit
c719c40540
@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
|
|||||||
]:
|
]:
|
||||||
mutated_args = {1: "result"}
|
mutated_args = {1: "result"}
|
||||||
self.defunctionalize(graph, node, mutated_args)
|
self.defunctionalize(graph, node, mutated_args)
|
||||||
|
elif (
|
||||||
|
at_target
|
||||||
|
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||||
|
):
|
||||||
|
mutated_args = {
|
||||||
|
1: "allreduce_in",
|
||||||
|
2: "residual",
|
||||||
|
3: "norm_out",
|
||||||
|
4: "quant_out",
|
||||||
|
5: "scale_out",
|
||||||
|
}
|
||||||
|
self.defunctionalize(graph, node, mutated_args)
|
||||||
# For some reason we need to specify the args for both
|
# For some reason we need to specify the args for both
|
||||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||||
# pathway gets the wrong answer.
|
# pathway gets the wrong answer.
|
||||||
|
|||||||
@ -75,8 +75,8 @@ def find_op_nodes(
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert isinstance(op, OpOverload)
|
assert isinstance(op, OpOverload)
|
||||||
if not op._schema.is_mutable:
|
|
||||||
yield from graph.find_nodes(op="call_function", target=op)
|
yield from graph.find_nodes(op="call_function", target=op)
|
||||||
|
|
||||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||||
if n.args[0] == op:
|
if n.args[0] == op:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user