diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f5fff344a1f4..c3c670422def 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -329,7 +329,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args ] - return super().run(*fake_args) + with self.fake_mode: + return super().run(*fake_args) def call_module(self, target: torch.fx.node.Target, args: Tuple[torch.fx.node.Argument,