diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b59a4a9dd1527..02e974b0f9e8c 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper: compiled_ptr = self.check_invariants_and_forward + aot_context = nullcontext() if envs.VLLM_USE_AOT_COMPILE: if hasattr(torch._dynamo.config, "enable_aot_compile"): - torch._dynamo.config.enable_aot_compile = True + aot_context = torch._dynamo.config.patch(enable_aot_compile=True) else: msg = "torch._dynamo.config.enable_aot_compile is not " msg += "available. AOT compile is disabled and please " msg += "upgrade PyTorch version to use AOT compile." logger.warning(msg) - self._compiled_callable = torch.compile( - compiled_ptr, - fullgraph=True, - dynamic=False, - backend=backend, - options=options, - ) + with aot_context: + self._compiled_callable = torch.compile( + compiled_ptr, + fullgraph=True, + dynamic=False, + backend=backend, + options=options, + ) if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)