diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 76068f86ebfb3..2625562aadd36 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -104,7 +104,8 @@ class FixFunctionalizationPass(VllmInductorPass): mutated_args = {1: "result"} self.defunctionalize(graph, node, mutated_args) elif ( - at_target + hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm") + and at_target == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default ): mutated_args = {