[Compile] Refactor. Move PostGradPassManager out of Compilation config (#29340)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Ilya Markov committed 2025-11-25 20:58:56 +01:00 (via GitHub)
parent c32a18cbe7
commit e7d776273d
2 changed files with 21 additions and 14 deletions
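
In short: VllmBackend now keeps its own deepcopy of CompilationConfig.inductor_compile_config (as self.inductor_config), installs the PostGradPassManager into that detached copy rather than into the shared config, and both compile call sites read the config from the backend. configure_post_pass now rejects a PostGradPassManager found in CompilationConfig instead of asserting that it matches.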


@@ -11,6 +11,7 @@ import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
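
For context on the new import, here is a minimal standalone sketch (with hypothetical keys, not vLLM code) of why a deepcopy rather than a shallow copy is the right way to detach a nested config dict: mutations through a shallow copy's nested values leak back into the original.

from copy import copy, deepcopy

inductor_compile_config = {"post_grad_custom_post_pass": None, "opts": {"cse": True}}

shallow = copy(inductor_compile_config)
shallow["opts"]["cse"] = False
print(inductor_compile_config["opts"]["cse"])  # False -- the original was mutated

inductor_compile_config["opts"]["cse"] = True
detached = deepcopy(inductor_compile_config)
detached["opts"]["cse"] = False
print(inductor_compile_config["opts"]["cse"])  # True -- original untouched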
@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.vllm_backend.compiler_manager.compile( self.vllm_backend.compiler_manager.compile(
submod, submod,
args, args,
self.compilation_config.inductor_compile_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
graph_index=index, graph_index=index,
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
@@ -531,6 +532,9 @@ class VllmBackend:
     sym_tensor_indices: list[int]
     input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
+    # Copy of CompilationConfig.inductor_compile_config plus
+    # an entry for the PostGradPassManager
+    inductor_config: dict[str, Any]

     def __init__(
         self,
@@ -561,25 +565,30 @@ class VllmBackend:
             self.compilation_config
         )

+        # Deepcopy the inductor config to detach the post-grad custom pass
+        # from CompilationConfig.
+        # We want to avoid PostGradPassManager in CompilationConfig because
+        # in the future we need PostGradPassManager.uuid() to be executed
+        # only at compile time.
+        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)

         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

     def configure_post_pass(self):
-        config = self.compilation_config
         self.pass_manager.configure(self.vllm_config)

         # Post-grad custom passes are run using the post_grad_custom_post_pass
         # hook. If a pass for that hook exists, add it to the pass manager.
-        inductor_config = config.inductor_compile_config
-        if self.pass_key in inductor_config:
-            if isinstance(inductor_config[self.pass_key], PostGradPassManager):
-                # PassManager already added to config, make sure it's correct
-                assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
+        if self.pass_key in self.inductor_config:
+            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
+                raise ValueError(
+                    "PostGradPassManager cannot be kept in CompilationConfig."
+                )
             else:
                 # Config should automatically wrap all inductor passes
-                assert isinstance(inductor_config[self.pass_key], InductorPass)
-                self.pass_manager.add(inductor_config[self.pass_key])
-        inductor_config[self.pass_key] = self.pass_manager
+                assert isinstance(self.inductor_config[self.pass_key], InductorPass)
+                self.pass_manager.add(self.inductor_config[self.pass_key])
+        self.inductor_config[self.pass_key] = self.pass_manager

     def __call__(
         self, graph: fx.GraphModule, example_inputs
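
To make the new flow concrete, here is a minimal standalone sketch of the pattern configure_post_pass now follows; PassManager, PASS_KEY, and the function signature are stand-ins, not vLLM's actual classes or API. The backend works on a detached copy of the config, folds any user-supplied post-grad pass into the manager, and installs the manager only in the copy, so the shared config never holds it.

from copy import deepcopy

PASS_KEY = "post_grad_custom_post_pass"

class PassManager:
    """Runs registered passes in order over a graph-like object."""

    def __init__(self) -> None:
        self.passes = []

    def add(self, p) -> None:
        self.passes.append(p)

    def __call__(self, graph):
        for p in self.passes:
            graph = p(graph)
        return graph

def configure_post_pass(shared_inductor_config: dict, manager: PassManager) -> dict:
    # Detach from the shared config so the manager never ends up in it.
    inductor_config = deepcopy(shared_inductor_config)
    if PASS_KEY in inductor_config:
        if isinstance(inductor_config[PASS_KEY], PassManager):
            # The shared config must never carry the manager itself.
            raise ValueError("PassManager cannot be kept in the shared config.")
        # Fold the user's single pass into the manager.
        manager.add(inductor_config[PASS_KEY])
    inductor_config[PASS_KEY] = manager
    return inductor_config

# Usage: a user-supplied pass gets wrapped; the caller's dict is untouched.
shared = {PASS_KEY: lambda g: g}
detached = configure_post_pass(shared, PassManager())
assert not isinstance(shared[PASS_KEY], PassManager)
assert isinstance(detached[PASS_KEY], PassManager)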
@@ -638,9 +647,7 @@ class VllmBackend:
         self.compilation_config.local_cache_dir = local_cache_dir

         # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
-        disable_cache = not is_compile_cache_enabled(
-            self.compilation_config.inductor_compile_config
-        )
+        disable_cache = not is_compile_cache_enabled(self.inductor_config)
         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
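
The cache opt-out check now consults the backend's detached copy as well. Illustrative only: the sketch below is a hypothetical shape for such an enablement check; vLLM's real is_compile_cache_enabled lives elsewhere in the codebase and may apply different and additional conditions. Only the VLLM_DISABLE_COMPILE_CACHE opt-out is named in the diff itself.

import os

def is_compile_cache_enabled(inductor_config: dict) -> bool:
    # Environment opt-out named in the comment above (hypothetical handling).
    if os.environ.get("VLLM_DISABLE_COMPILE_CACHE") == "1":
        return False
    # A custom post-grad pass would need to hash into the cache key (e.g. via
    # a stable uuid()); otherwise caching would be unsafe.
    custom_pass = inductor_config.get("post_grad_custom_post_pass")
    return custom_pass is None or hasattr(custom_pass, "uuid")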


@ -107,7 +107,7 @@ class PiecewiseBackend:
entry.runnable = self.vllm_backend.compiler_manager.compile( entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args, args,
self.compilation_config.inductor_compile_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
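
Both call sites now follow the same pattern: the dict handed to compiler_manager.compile is the backend's detached self.vllm_backend.inductor_config, which carries the PostGradPassManager entry, while the CompilationConfig object is passed separately and stays free of the pass manager.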