[torch.compile] use depyf to dump torch.compile internals (#10972)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-11 10:43:05 -08:00 committed by GitHub
parent fd22220687
commit 91642db952
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 66 additions and 42 deletions

View File

@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging torch.compile

View File

@ -9,7 +9,7 @@ import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
"""
def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: List[str],
compilation_configs: CompilationConfig, graph_pool):
compile_submod_names: List[str], vllm_config: VllmConfig,
graph_pool):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_configs = compilation_configs
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
def run(self, *args):
fake_args = [
@ -182,15 +183,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compiled_graph_for_general_shape = wrap_inductor(
submod,
args,
self.compilation_configs.inductor_compile_config,
self.compilation_configs,
self.compilation_config.inductor_compile_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
use_inductor=self.compilation_configs.use_inductor)
use_inductor=self.compilation_config.use_inductor)
self.module.__dict__[target] = PiecewiseBackend(
submod, self.compilation_configs, self.graph_pool, index,
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape)
@ -211,7 +212,8 @@ class VllmBackend:
which handles the post-grad passes.
"""
compilation_configs: CompilationConfig
vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
@ -227,7 +229,7 @@ class VllmBackend:
def __init__(
self,
compilation_configs: CompilationConfig,
vllm_config: VllmConfig,
):
global global_graph_pool
if global_graph_pool is None:
@ -244,13 +246,14 @@ class VllmBackend:
self.sym_tensor_indices = []
self.input_buffers = []
self.compilation_configs = compilation_configs
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def configure_post_pass(self):
config = self.compilation_configs
config = self.compilation_config
self.post_grad_pass_manager.configure(config.pass_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
@ -271,7 +274,7 @@ class VllmBackend:
from .monitor import torch_compile_start_time
dynamo_time = time.time() - torch_compile_start_time
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
self.compilation_configs.compilation_time += dynamo_time
self.compilation_config.compilation_time += dynamo_time
# we control the compilation process, each instance can only be
# called once
@ -281,7 +284,7 @@ class VllmBackend:
self.configure_post_pass()
self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.splitting_ops)
graph, self.compilation_config.splitting_ops)
from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
@ -298,13 +301,13 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.compilation_configs,
self.vllm_config,
self.graph_pool).run(*example_inputs)
self._called = True
if not self.compilation_configs.use_cudagraph or \
not self.compilation_configs.cudagraph_copy_inputs:
if not self.compilation_config.use_cudagraph or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm
# if we need to copy input buffers for cudagraph
@ -364,10 +367,9 @@ class ConcreteSizeEntry:
class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule,
compilation_configs: CompilationConfig, graph_pool: Any,
piecewise_compile_index: int, total_piecewise_compiles: int,
sym_shape_indices: List[int],
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable):
"""
The backend for piecewise compilation.
@ -375,7 +377,7 @@ class PiecewiseBackend:
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_configs.compile_sizes`.
`compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
@ -383,7 +385,8 @@ class PiecewiseBackend:
compile it first, and then capture cudagraph.
"""
self.graph = graph
self.compilation_configs = compilation_configs
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
@ -393,10 +396,10 @@ class PiecewiseBackend:
piecewise_compile_index == total_piecewise_compiles - 1)
self.compile_sizes: Set[int] = set(
self.compilation_configs.compile_sizes)
self.compilation_config.compile_sizes)
self.capture_sizes: Set[int] = set(
self.compilation_configs.capture_sizes
) if self.compilation_configs.use_cudagraph else set()
self.compilation_config.capture_sizes
) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False
@ -423,7 +426,7 @@ class PiecewiseBackend:
self.first_run_finished = True
# no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
end_monitoring_torch_compile(self.vllm_config)
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
@ -443,28 +446,28 @@ class PiecewiseBackend:
entry.runnable = wrap_inductor(
self.graph,
args,
self.compilation_configs.inductor_compile_config,
self.compilation_configs,
self.compilation_config.inductor_compile_config,
self.compilation_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
use_inductor=self.compilation_configs.use_inductor)
use_inductor=self.compilation_config.use_inductor)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
end_monitoring_torch_compile(self.vllm_config)
if not entry.use_cudagraph:
return entry.runnable(*args)
if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_configs.cudagraph_num_of_warmups,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)

View File

@ -185,7 +185,7 @@ def _support_torch_compile(
"Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}.")
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config.compilation_config)
start_monitoring_torch_compile(self.vllm_config)
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,

View File

@ -1,19 +1,36 @@
import os
import time
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
context_manager = None
torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(compilation_config: CompilationConfig):
def start_monitoring_torch_compile(vllm_config: VllmConfig):
global torch_compile_start_time
torch_compile_start_time = time.time()
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE and \
compilation_config.debug_dump_path:
import depyf
path = os.path.join(compilation_config.debug_dump_path,
f"rank_{vllm_config.parallel_config.rank}")
global context_manager
context_manager = depyf.prepare_debug(path)
context_manager.__enter__()
def end_monitoring_torch_compile(compilation_config: CompilationConfig):
def end_monitoring_torch_compile(vllm_config: VllmConfig):
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE:
logger.info("torch.compile takes %.2f s in total",
compilation_config.compilation_time)
global context_manager
if context_manager is not None:
context_manager.__exit__(None, None, None)
context_manager = None

View File

@ -32,8 +32,8 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings
# compiling the forward method
backend = get_current_vllm_config(
).compilation_config.init_backend()
vllm_config = get_current_vllm_config()
backend = vllm_config.compilation_config.init_backend(vllm_config)
compiled_callable = torch.compile(
self.forward,

View File

@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
level: int = 0
debug_dump_path: str = ""
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
@ -2394,7 +2396,7 @@ class CompilationConfig(BaseModel):
self.static_forward_context = {}
self.compilation_time = 0.0
def init_backend(self) -> Union[str, Callable]:
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -2413,7 +2415,7 @@ class CompilationConfig(BaseModel):
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)
return VllmBackend(vllm_config)
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
"""To complete the initialization of config,

View File

@ -1162,7 +1162,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend()
backend = self.vllm_config.compilation_config.init_backend(
self.vllm_config)
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,