From c719c40540a85c1e6aeee9af20f29db581da27f0 Mon Sep 17 00:00:00 2001
From: elvischenv <219235043+elvischenv@users.noreply.github.com>
Date: Wed, 3 Dec 2025 13:15:50 +0800
Subject: [PATCH] [Bugfix] Defunctionalize TRTLLM AR+Norm op for avoiding
 extra clone kernel before it (#29631)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Luka Govedič
Co-authored-by: Luka Govedič
---
 vllm/compilation/fix_functionalization.py | 12 ++++++++++++
 vllm/compilation/fx_utils.py              |  4 ++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index 126ad35e527a..76068f86ebfb 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
             ]:
                 mutated_args = {1: "result"}
                 self.defunctionalize(graph, node, mutated_args)
+            elif (
+                at_target
+                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
+            ):
+                mutated_args = {
+                    1: "allreduce_in",
+                    2: "residual",
+                    3: "norm_out",
+                    4: "quant_out",
+                    5: "scale_out",
+                }
+                self.defunctionalize(graph, node, mutated_args)
             # For some reason we need to specify the args for both
             # silu_and_mul and silu_and_mul_quant. The kwargs
             # pathway gets the wrong answer.
diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py
index f2497950fc22..3650ee6b4174 100644
--- a/vllm/compilation/fx_utils.py
+++ b/vllm/compilation/fx_utils.py
@@ -75,8 +75,8 @@ def find_op_nodes(
         return
 
     assert isinstance(op, OpOverload)
-    if not op._schema.is_mutable:
-        yield from graph.find_nodes(op="call_function", target=op)
+
+    yield from graph.find_nodes(op="call_function", target=op)
 
     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
         if n.args[0] == op:
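
Note on the change: the pass rewrites auto_functionalized(op, ...) wrapper
nodes into direct calls to the mutable custom op, so Inductor no longer
materializes an extra clone of the mutated inputs before the TRTLLM AR+Norm
kernel. The fx_utils.py change is the matching fix: once the op is
defunctionalized it appears as a plain call_function node even though its
schema is mutable, so find_op_nodes must yield direct calls unconditionally
instead of only for non-mutable schemas. Below is a minimal self-contained
sketch of that updated matching logic, using a toy Node dataclass and plain
strings in place of torch.fx.Node and OpOverload; all names in it are
illustrative, not vLLM's real classes.

# Standalone sketch of the patched find_op_nodes matching behavior.
from dataclasses import dataclass, field
from typing import Iterator

AUTO_FUNCTIONALIZED = "auto_functionalized"  # stand-in for the HOP target


@dataclass
class Node:
    op: str                                   # node kind, e.g. "call_function"
    target: str                               # called op, or the HOP wrapper
    args: tuple = field(default_factory=tuple)


def find_op_nodes(op: str, graph: list[Node]) -> Iterator[Node]:
    # Direct calls: after defunctionalization a mutable op shows up as a
    # plain call_function node, so it is matched unconditionally (the old
    # code skipped this branch whenever op._schema.is_mutable was True).
    for n in graph:
        if n.op == "call_function" and n.target == op:
            yield n
    # Functionalized calls: the wrapped op is the wrapper's first argument.
    for n in graph:
        if n.op == "call_function" and n.target == AUTO_FUNCTIONALIZED:
            if n.args and n.args[0] == op:
                yield n


# Both the defunctionalized node and a still-wrapped node are found.
ar_norm = "flashinfer_trtllm_fused_allreduce_norm"
graph = [
    Node("call_function", ar_norm),
    Node("call_function", AUTO_FUNCTIONALIZED, (ar_norm,)),
]
assert len(list(find_op_nodes(ar_norm, graph))) == 2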