[Compile] Refactor. Move PostGradPassManager out of Compilation config (#29340)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Author: Ilya Markov (committed by GitHub)
Date: 2025-11-25 20:58:56 +01:00
parent c32a18cbe7
commit e7d776273d
2 changed files with 21 additions and 14 deletions
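
At its core, the change is a copy-then-mutate pattern: the backend deep-copies the user's inductor config and injects the pass manager only into its private copy. A minimal, self-contained sketch (the dict keys and `object()` placeholder are illustrative, not vLLM's actual values):

```python
from copy import deepcopy

# Stand-in for CompilationConfig.inductor_compile_config.
user_config = {"max_autotune": True}

# Stand-in for VllmBackend.inductor_config: the backend's private copy.
backend_config = deepcopy(user_config)

# The PostGradPassManager entry is injected only into the copy...
backend_config["post_grad_custom_post_pass"] = object()

# ...so the user-facing config never carries the pass manager.
assert "post_grad_custom_post_pass" not in user_config
```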


@@ -11,6 +11,7 @@ import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         self.vllm_backend.compiler_manager.compile(
             submod,
             args,
-            self.compilation_config.inductor_compile_config,
+            self.vllm_backend.inductor_config,
             self.compilation_config,
             graph_index=index,
             num_graphs=len(self.compile_submod_names),
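
For context, a toy sketch of the piecewise pattern this call site implements: an `fx.Interpreter` that compiles each submodule it visits with the backend-owned config dict. The class below is illustrative only and uses plain `torch.compile` where vLLM goes through `compiler_manager.compile`:

```python
import torch
import torch.fx as fx

class ToyPiecewiseInterpreter(fx.Interpreter):
    """Compiles each submodule with a backend-owned inductor config."""

    def __init__(self, gm: fx.GraphModule, inductor_config: dict):
        super().__init__(gm)
        self.inductor_config = inductor_config  # the backend's private copy

    def call_module(self, target, args, kwargs):
        submod = self.fetch_attr(target)
        # vLLM routes this through compiler_manager.compile(...); the point
        # here is only that the backend's dict is what gets threaded in.
        compiled = torch.compile(submod, options=self.inductor_config)
        return compiled(*args, **kwargs)
```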
@@ -531,6 +532,9 @@ class VllmBackend:
     sym_tensor_indices: list[int]
     input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
+    # Copy of CompilationConfig.inductor_compile_config +
+    # an entry for PostGradPassManager
+    inductor_config: dict[str, Any]

     def __init__(
         self,
@@ -561,25 +565,30 @@
             self.compilation_config
         )

+        # Deepcopy the inductor config to detach the post-grad custom pass
+        # from CompilationConfig.
+        # We want to avoid PostGradPassManager in CompilationConfig because
+        # in the future we need PostGradPassManager.uuid() to be executed
+        # only at compile time.
+        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)

         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

     def configure_post_pass(self):
-        config = self.compilation_config
         self.pass_manager.configure(self.vllm_config)

         # Post-grad custom passes are run using the post_grad_custom_post_pass
         # hook. If a pass for that hook exists, add it to the pass manager.
-        inductor_config = config.inductor_compile_config
-        if self.pass_key in inductor_config:
-            if isinstance(inductor_config[self.pass_key], PostGradPassManager):
-                # PassManager already added to config, make sure it's correct
-                assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
+        if self.pass_key in self.inductor_config:
+            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
+                raise ValueError(
+                    "PostGradPassManager can not be kept in CompilationConfig."
+                )
             else:
                 # Config should automatically wrap all inductor passes
-                assert isinstance(inductor_config[self.pass_key], InductorPass)
-                self.pass_manager.add(inductor_config[self.pass_key])
-        inductor_config[self.pass_key] = self.pass_manager
+                assert isinstance(self.inductor_config[self.pass_key], InductorPass)
+                self.pass_manager.add(self.inductor_config[self.pass_key])
+        self.inductor_config[self.pass_key] = self.pass_manager

     def __call__(
         self, graph: fx.GraphModule, example_inputs
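
The resulting `configure_post_pass` control flow, sketched with toy stand-ins for `PostGradPassManager` and `InductorPass` (not vLLM's real classes): a manager found in the config is now an error, while a plain pass is folded into the backend's manager, which then occupies the hook key in the copied dict.

```python
class InductorPass:
    pass

class PostGradPassManager(InductorPass):
    def __init__(self):
        self.passes = []

    def add(self, p: InductorPass):
        self.passes.append(p)

PASS_KEY = "post_grad_custom_post_pass"

def configure_post_pass(inductor_config: dict, pass_manager: PostGradPassManager):
    if PASS_KEY in inductor_config:
        existing = inductor_config[PASS_KEY]
        if isinstance(existing, PostGradPassManager):
            # After this refactor the manager must never come from the config.
            raise ValueError("PostGradPassManager can not be kept in CompilationConfig.")
        # The config is expected to wrap user passes as InductorPass.
        assert isinstance(existing, InductorPass)
        pass_manager.add(existing)
    # The manager lives only in the backend's deep-copied dict.
    inductor_config[PASS_KEY] = pass_manager
```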
@@ -638,9 +647,7 @@
         self.compilation_config.local_cache_dir = local_cache_dir

         # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
-        disable_cache = not is_compile_cache_enabled(
-            self.compilation_config.inductor_compile_config
-        )
+        disable_cache = not is_compile_cache_enabled(self.inductor_config)

         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
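
The motivation stated in the commit, that `PostGradPassManager.uuid()` should execute only at compile time, can be illustrated with a toy cache-key function (hashlib-based, not vLLM's actual code): hashing the manager early would freeze its identity before all passes are registered.

```python
import hashlib
import json

def compile_cache_key(inductor_config: dict) -> str:
    parts = []
    for key, value in sorted(inductor_config.items()):
        # Objects exposing uuid() (e.g. a pass manager) contribute their
        # compile-time identity; plain values are serialized as-is.
        parts.append(value.uuid() if hasattr(value, "uuid") else json.dumps({key: value}))
    return hashlib.sha256("|".join(parts).encode()).hexdigest()
```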


@@ -107,7 +107,7 @@ class PiecewiseBackend:
         entry.runnable = self.vllm_backend.compiler_manager.compile(
             self.graph,
             args,
-            self.compilation_config.inductor_compile_config,
+            self.vllm_backend.inductor_config,
             self.compilation_config,
             graph_index=self.piecewise_compile_index,
             num_graphs=self.total_piecewise_compiles,
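
A closing note: both piecewise call sites now read the same backend-owned mapping, so the manager injected once by `configure_post_pass` is visible to every submodule compilation. Toy illustration (not vLLM code):

```python
backend_inductor_config = {"max_autotune": False}
interpreter_view = backend_inductor_config  # PiecewiseCompileInterpreter's view
piecewise_view = backend_inductor_config    # PiecewiseBackend's view

backend_inductor_config["post_grad_custom_post_pass"] = object()
assert "post_grad_custom_post_pass" in interpreter_view
assert interpreter_view is piecewise_view
```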