diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index 126ad35e527a..76068f86ebfb 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
             ]:
                 mutated_args = {1: "result"}
                 self.defunctionalize(graph, node, mutated_args)
+            elif (
+                at_target
+                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
+            ):
+                mutated_args = {
+                    1: "allreduce_in",
+                    2: "residual",
+                    3: "norm_out",
+                    4: "quant_out",
+                    5: "scale_out",
+                }
+                self.defunctionalize(graph, node, mutated_args)
             # For some reason we need to specify the args for both
             # silu_and_mul and silu_and_mul_quant. The kwargs
             # pathway gets the wrong answer.
diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py
index f2497950fc22..3650ee6b4174 100644
--- a/vllm/compilation/fx_utils.py
+++ b/vllm/compilation/fx_utils.py
@@ -75,8 +75,8 @@ def find_op_nodes(
         return
 
     assert isinstance(op, OpOverload)
-    if not op._schema.is_mutable:
-        yield from graph.find_nodes(op="call_function", target=op)
+
+    yield from graph.find_nodes(op="call_function", target=op)
 
     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
         if n.args[0] == op:
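
The `fx_utils.py` change drops the `is_mutable` guard so `find_op_nodes` yields direct `call_function` nodes for mutable ops too, letting passes find an op whether or not it is still wrapped in `auto_functionalized`. The `fix_functionalization.py` hunk teaches the pass to unwrap `flashinfer_trtllm_fused_allreduce_norm`, whose tuple outputs at indices 1–5 alias the in-place-mutated inputs named in `mutated_args`. To make the `mutated_args` mapping concrete, here is a simplified, self-contained sketch of what a `defunctionalize`-style rewrite does; it is not vLLM's actual helper, and `defunctionalize_sketch` is a hypothetical name:

```python
import operator

import torch.fx as fx


def defunctionalize_sketch(
    graph: fx.Graph, node: fx.Node, mutated_args: dict[int, str]
) -> None:
    """Simplified stand-in for FixFunctionalizationPass.defunctionalize.

    `node` is an auto_functionalized(op, ...) call whose tuple output is
    consumed via operator.getitem; `mutated_args` maps each tuple index to
    the kwarg that the underlying op mutates in place (index 0 is the op
    itself, so mutated outputs start at index 1).
    """
    # Rewire getitem(node, idx) users to the tensor the op mutates in place:
    # after defunctionalization that tensor holds the result directly.
    for user in list(node.users):
        if user.op == "call_function" and user.target == operator.getitem:
            idx = user.args[1]
            if idx in mutated_args:
                user.replace_all_uses_with(node.kwargs[mutated_args[idx]])
                graph.erase_node(user)
    # Re-emit the op as a direct, mutating call and drop the functionalized
    # wrapper. (The real pass special-cases some ops, e.g. silu_and_mul, to
    # pass positional args rather than kwargs, per the comment in the hunk.)
    with graph.inserting_before(node):
        graph.call_function(node.args[0], kwargs=node.kwargs)
    graph.erase_node(node)
```

Under this scheme, the new branch's `mutated_args = {1: "allreduce_in", 2: "residual", ...}` means a downstream `getitem(node, 2)` is replaced by the `residual` input tensor, and the fused allreduce+norm op is then called directly so the compiled graph avoids the functionalization wrapper and its extra copies.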