Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Compile] Refactor. Move PostGradPassManager out of Compilation config (#29340)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
commit e7d776273d
parent c32a18cbe7
@@ -11,6 +11,7 @@ import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
 
@@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             self.vllm_backend.compiler_manager.compile(
                 submod,
                 args,
-                self.compilation_config.inductor_compile_config,
+                self.vllm_backend.inductor_config,
                 self.compilation_config,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
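The same one-argument swap appears twice in this commit: here in PiecewiseCompileInterpreter and again in PiecewiseBackend at the end. A hypothetical stub, mirroring only the call shape visible above (the real CompilerManager.compile in vLLM does far more), shows where the detached dict flows:

from typing import Any

import torch.fx as fx


def compile_stub(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    # Now VllmBackend.inductor_config rather than
    # CompilationConfig.inductor_compile_config.
    inductor_config: dict[str, Any],
    compilation_config: Any,
    *,
    graph_index: int,
    num_graphs: int,
) -> Any:
    """Sketch only: compile one piecewise subgraph; the pass-manager
    entry travels inside the inductor_config dict."""
    ...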
@@ -531,6 +532,9 @@ class VllmBackend:
     sym_tensor_indices: list[int]
     input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
+    # Copy of CompilationConfig.inductor_compile_config +
+    # an entry for PostGradPassManager
+    inductor_config: dict[str, Any]
 
     def __init__(
         self,
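The comment on the new attribute states the invariant: the backend's dict is a copy of CompilationConfig.inductor_compile_config plus one pass-manager entry. A minimal, runnable sketch of why the copy matters, with a hypothetical stand-in for the pass manager and an assumed example config:

from copy import deepcopy
from typing import Any


class FakePassManager:
    """Hypothetical stand-in for PostGradPassManager."""


PASS_KEY = "post_grad_custom_post_pass"

# Assumed example of a user-provided inductor config on CompilationConfig.
shared: dict[str, Any] = {"max_autotune": True}

detached = deepcopy(shared)            # what VllmBackend.inductor_config holds
detached[PASS_KEY] = FakePassManager() # installed only in the private copy

assert PASS_KEY not in shared          # user-visible config stays clean
assert PASS_KEY in detached            # only the detached copy grew the entry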
@@ -561,25 +565,30 @@ class VllmBackend:
             self.compilation_config
         )
 
+        # Deepcopy the inductor config to detach the post-grad custom pass
+        # from CompilationConfig.
+        # We want to avoid PostGradPassManager in CompilationConfig because
+        # in future we need PostGradPassManager.uuid() to be executed
+        # only at compile time.
+        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
     def configure_post_pass(self):
-        config = self.compilation_config
         self.pass_manager.configure(self.vllm_config)
 
         # Post-grad custom passes are run using the post_grad_custom_post_pass
         # hook. If a pass for that hook exists, add it to the pass manager.
-        inductor_config = config.inductor_compile_config
-        if self.pass_key in inductor_config:
-            if isinstance(inductor_config[self.pass_key], PostGradPassManager):
-                # PassManager already added to config, make sure it's correct
-                assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
+        if self.pass_key in self.inductor_config:
+            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
+                raise ValueError(
+                    "PostGradPassManager can not be kept in CompilationConfig."
+                )
             else:
                 # Config should automatically wrap all inductor passes
-                assert isinstance(inductor_config[self.pass_key], InductorPass)
-                self.pass_manager.add(inductor_config[self.pass_key])
-        inductor_config[self.pass_key] = self.pass_manager
+                assert isinstance(self.inductor_config[self.pass_key], InductorPass)
+                self.pass_manager.add(self.inductor_config[self.pass_key])
+        self.inductor_config[self.pass_key] = self.pass_manager
 
     def __call__(
         self, graph: fx.GraphModule, example_inputs
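Read side by side, the hunk changes configure_post_pass from tolerating a pre-installed manager (old: assert matching uuid()) to rejecting it outright. A self-contained sketch of the new control flow, with hypothetical stand-ins for vLLM's InductorPass and PostGradPassManager:

from typing import Any


class InductorPass:
    """Stand-in for vLLM's InductorPass base class."""


class PostGradPassManager(InductorPass):
    """Stand-in: collects passes to run on the post-grad graph."""

    def __init__(self) -> None:
        self.passes: list[InductorPass] = []

    def add(self, inductor_pass: InductorPass) -> None:
        self.passes.append(inductor_pass)


PASS_KEY = "post_grad_custom_post_pass"


def configure_post_pass(
    inductor_config: dict[str, Any], pass_manager: PostGradPassManager
) -> None:
    if PASS_KEY in inductor_config:
        existing = inductor_config[PASS_KEY]
        if isinstance(existing, PostGradPassManager):
            # New behavior: the manager must never be stored in the
            # user-facing config in the first place.
            raise ValueError(
                "PostGradPassManager can not be kept in CompilationConfig."
            )
        # A user-supplied hook is expected to already be an InductorPass.
        assert isinstance(existing, InductorPass)
        pass_manager.add(existing)
    # The manager becomes the single post_grad_custom_post_pass entry.
    inductor_config[PASS_KEY] = pass_manager


# Usage: a user pass gets folded into the manager, which takes over the hook.
cfg: dict[str, Any] = {PASS_KEY: InductorPass()}
pm = PostGradPassManager()
configure_post_pass(cfg, pm)
assert cfg[PASS_KEY] is pm and len(pm.passes) == 1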
@@ -638,9 +647,7 @@ class VllmBackend:
         self.compilation_config.local_cache_dir = local_cache_dir
 
         # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
-        disable_cache = not is_compile_cache_enabled(
-            self.compilation_config.inductor_compile_config
-        )
+        disable_cache = not is_compile_cache_enabled(self.inductor_config)
 
         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
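With the detached copy, the opt-out check collapses from three lines to one and now inspects the same dict that is later handed to Inductor. A hypothetical sketch of such a check; the real is_compile_cache_enabled lives in vLLM and its exact logic is assumed here, not confirmed:

import os
from typing import Any


def is_compile_cache_enabled(inductor_config: dict[str, Any]) -> bool:
    # Assumption: the VLLM_DISABLE_COMPILE_CACHE env var (named in the
    # comment above) acts as a hard opt-out.
    if os.environ.get("VLLM_DISABLE_COMPILE_CACHE", "0") == "1":
        return False
    # Assumption: an explicit config entry can also force caches off.
    return not inductor_config.get("force_disable_caches", False)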
@@ -107,7 +107,7 @@ class PiecewiseBackend:
             entry.runnable = self.vllm_backend.compiler_manager.compile(
                 self.graph,
                 args,
-                self.compilation_config.inductor_compile_config,
+                self.vllm_backend.inductor_config,
                 self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,