diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 2d8dd4c51c7ef..1773913d0b6c6 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -11,6 +11,7 @@
 import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 self.vllm_backend.compiler_manager.compile(
                     submod,
                     args,
-                    self.compilation_config.inductor_compile_config,
+                    self.vllm_backend.inductor_config,
                     self.compilation_config,
                     graph_index=index,
                     num_graphs=len(self.compile_submod_names),
@@ -531,6 +532,9 @@ class VllmBackend:
     sym_tensor_indices: list[int]
     input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
+    # Copy of CompilationConfig.inductor_compile_config +
+    # an entry for PostGradPassManager
+    inductor_config: dict[str, Any]
 
     def __init__(
         self,
@@ -561,25 +565,30 @@ class VllmBackend:
             self.compilation_config
         )
 
+        # Deepcopy the inductor config to detach the post-grad custom pass
+        # from CompilationConfig.
+        # We want to avoid PostGradPassManager in CompilationConfig because
+        # in future we need PostGradPassManager.uuid() to be executed
+        # only at compile time.
+        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
     def configure_post_pass(self):
-        config = self.compilation_config
         self.pass_manager.configure(self.vllm_config)
 
         # Post-grad custom passes are run using the post_grad_custom_post_pass
         # hook. If a pass for that hook exists, add it to the pass manager.
-        inductor_config = config.inductor_compile_config
-        if self.pass_key in inductor_config:
-            if isinstance(inductor_config[self.pass_key], PostGradPassManager):
-                # PassManager already added to config, make sure it's correct
-                assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
+        if self.pass_key in self.inductor_config:
+            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
+                raise ValueError(
+                    "PostGradPassManager can not be kept in CompilationConfig."
+                )
             else:
                 # Config should automatically wrap all inductor passes
-                assert isinstance(inductor_config[self.pass_key], InductorPass)
-                self.pass_manager.add(inductor_config[self.pass_key])
-                inductor_config[self.pass_key] = self.pass_manager
+                assert isinstance(self.inductor_config[self.pass_key], InductorPass)
+                self.pass_manager.add(self.inductor_config[self.pass_key])
+                self.inductor_config[self.pass_key] = self.pass_manager
 
     def __call__(
         self, graph: fx.GraphModule, example_inputs
@@ -638,9 +647,7 @@ class VllmBackend:
         self.compilation_config.local_cache_dir = local_cache_dir
 
         # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
-        disable_cache = not is_compile_cache_enabled(
-            self.compilation_config.inductor_compile_config
-        )
+        disable_cache = not is_compile_cache_enabled(self.inductor_config)
 
         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index 2931580afbbb0..e535d2c461c6e 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -107,7 +107,7 @@ class PiecewiseBackend:
             entry.runnable = self.vllm_backend.compiler_manager.compile(
                 self.graph,
                 args,
-                self.compilation_config.inductor_compile_config,
+                self.vllm_backend.inductor_config,
                 self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
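For context, the core pattern in this change is: keep the PostGradPassManager out of CompilationConfig by deepcopying inductor_compile_config once in VllmBackend and mutating only that copy, raising if a pass manager is already present in the user-facing config. Below is a minimal, self-contained sketch of that pattern under stated assumptions; ToyPassManager, ToyConfig-style dicts, and build_inductor_config are hypothetical stand-ins for illustration, not vLLM APIs.

```python
from copy import deepcopy
from typing import Any

# The hook name matches the post_grad_custom_post_pass hook mentioned in the
# diff's comments; the rest of the names here are illustrative only.
PASS_KEY = "post_grad_custom_post_pass"


class ToyPassManager:
    """Stand-in for PostGradPassManager: collects wrapped passes."""

    def __init__(self) -> None:
        self.passes: list[Any] = []

    def add(self, p: Any) -> None:
        self.passes.append(p)


def build_inductor_config(user_config: dict[str, Any]) -> dict[str, Any]:
    # Deepcopy so the pass manager is attached only to the compile-time copy,
    # never written back into the user-facing (cacheable) config.
    inductor_config = deepcopy(user_config)

    if isinstance(inductor_config.get(PASS_KEY), ToyPassManager):
        # Mirrors the new behavior: a pass manager must not live in the config.
        raise ValueError("Pass manager cannot be kept in the user config.")

    pass_manager = ToyPassManager()
    if PASS_KEY in inductor_config:
        # Wrap the user-provided pass, then hand the backend the manager instead.
        pass_manager.add(inductor_config[PASS_KEY])
    inductor_config[PASS_KEY] = pass_manager
    return inductor_config


if __name__ == "__main__":
    user_cfg = {PASS_KEY: lambda graph: graph}  # a user-supplied post-grad pass
    compiled_cfg = build_inductor_config(user_cfg)
    # The original config is untouched; only the deep copy holds the manager.
    assert not isinstance(user_cfg[PASS_KEY], ToyPassManager)
    assert isinstance(compiled_cfg[PASS_KEY], ToyPassManager)
```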