vllm/vllm/compilation/piecewise_backend.py
Ilya Markov e7d776273d
[Compile] Refactor. Move PostGradPassManager out of Compilation config (#29340)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
2025-11-25 19:58:56 +00:00

122 lines
4.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections.abc import Callable
from typing import Any
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
compiled: bool = False
runnable: Callable = None # type: ignore
class PiecewiseBackend:
def __init__(
self,
graph: fx.GraphModule,
vllm_config: VllmConfig,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend,
):
"""
The backend for piecewise compilation.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
"""
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
self.first_run_finished = False
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for different shapes that we need to compile
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
runnable=self.compiled_graph_for_general_shape,
)
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape]
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
return entry.runnable(*args)