[Bugfix] Defunctionalize TRTLLM AR+Norm op to avoid an extra clone kernel before it (#29631)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: elvischenv
Date: 2025-12-03 13:15:50 +08:00 (committed by GitHub)
Commit: c719c40540 (parent b08025a83b)
2 changed files with 14 additions and 2 deletions

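Background for the fix: when a custom op mutates its arguments, torch.compile traces it behind the auto_functionalized higher-order op, and leaving that wrapper in the final graph forces a copy of every to-be-mutated buffer right before the op runs. The snippet below is a minimal, hypothetical sketch of that cost (a toy inplace_scale stand-in, not the vLLM/FlashInfer kernel); the pass change below instead emits the direct in-place call for the TRTLLM fused allreduce+norm op.

import torch


def inplace_scale(buf: torch.Tensor, scale: float) -> None:
    # Stand-in for a mutating custom op such as the fused AR+Norm kernel.
    buf.mul_(scale)


def functionalized_scale(buf: torch.Tensor, scale: float) -> torch.Tensor:
    # Functional wrapper semantics: the caller's buffer must stay untouched,
    # so the wrapper clones it first -- this clone is the extra kernel the
    # commit title refers to.
    cloned = buf.clone()
    inplace_scale(cloned, scale)
    return cloned


buf = torch.ones(4)
out = functionalized_scale(buf, 2.0)  # buf unchanged, one extra allocation + copy
inplace_scale(buf, 2.0)               # buf mutated in place, no clone emitted
assert torch.equal(out, buf)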
File: vllm/compilation/fix_functionalization.py

@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
             ]:
                 mutated_args = {1: "result"}
                 self.defunctionalize(graph, node, mutated_args)
+            elif (
+                at_target
+                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
+            ):
+                mutated_args = {
+                    1: "allreduce_in",
+                    2: "residual",
+                    3: "norm_out",
+                    4: "quant_out",
+                    5: "scale_out",
+                }
+                self.defunctionalize(graph, node, mutated_args)
             # For some reason we need to specify the args for both
             # silu_and_mul and silu_and_mul_quant. The kwargs
             # pathway gets the wrong answer.
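In the mapping above, the keys are indices into the auto_functionalized node's output tuple and the values name the kwargs whose buffers the op mutates in place; defunctionalize uses the mapping to point every getitem(node, i) user back at the original input tensor before re-emitting the op as a direct call. A toy restatement of that redirection with plain dicts (hypothetical node names, not torch.fx or vLLM code):

def redirect_getitem_users(
    wrapper_kwargs: dict[str, str],       # kwarg name -> producing node
    getitem_users: dict[int, list[str]],  # output-tuple index -> consumer nodes
    mutated_args: dict[int, str],         # output-tuple index -> mutated kwarg name
) -> dict[str, str]:
    """Return {consumer: node it should now read from}."""
    return {
        consumer: wrapper_kwargs[mutated_args[idx]]
        for idx, consumers in getitem_users.items()
        for consumer in consumers
    }


print(
    redirect_getitem_users(
        wrapper_kwargs={"allreduce_in": "allreduce_buf", "residual": "residual_buf"},
        getitem_users={1: ["attn_out_user"], 2: ["next_residual_user"]},
        mutated_args={1: "allreduce_in", 2: "residual"},
    )
)
# -> {'attn_out_user': 'allreduce_buf', 'next_residual_user': 'residual_buf'}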

File: vllm/compilation/fx_utils.py

@@ -75,8 +75,8 @@ def find_op_nodes(
         return
     assert isinstance(op, OpOverload)
-    if not op._schema.is_mutable:
-        yield from graph.find_nodes(op="call_function", target=op)
+    yield from graph.find_nodes(op="call_function", target=op)
+    if not op._schema.is_mutable:
         return
 
     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
         if n.args[0] == op:
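With this reordering, find_op_nodes yields direct call_function nodes for every op, mutable or not, and only additionally scans auto_functionalized wrappers for mutable ops; that is what lets later passes and tests still locate the AR+Norm op once FixFunctionalizationPass has unwrapped it. A standalone toy restatement of that lookup order (toy node type and string targets, not torch.fx):

from collections.abc import Iterator
from dataclasses import dataclass

AUTO_FUNCTIONALIZED = "auto_functionalized"  # stand-in for the real HOP target


@dataclass(frozen=True)
class ToyNode:
    target: object
    args: tuple = ()


def find_op_nodes_toy(op, nodes: list[ToyNode], mutable: bool) -> Iterator[ToyNode]:
    # Direct calls are yielded for every op; for a mutable op this covers the case
    # where the pass above has already rewritten the wrapper into a direct call.
    yield from (n for n in nodes if n.target == op)
    if not mutable:
        return
    # Mutable ops may additionally still appear wrapped in auto_functionalized.
    yield from (
        n
        for n in nodes
        if n.target == AUTO_FUNCTIONALIZED and n.args and n.args[0] == op
    )


nodes = [
    ToyNode(AUTO_FUNCTIONALIZED, ("fused_ar_norm",)),  # wrapped form (before the pass)
    ToyNode("fused_ar_norm"),                          # direct form (after the pass)
]
assert [n.target for n in find_op_nodes_toy("fused_ar_norm", nodes, mutable=True)] == [
    "fused_ar_norm",
    AUTO_FUNCTIONALIZED,
]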