From 91642db952458fbb6ae7c2d167757dc86b105991 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 11 Dec 2024 10:43:05 -0800 Subject: [PATCH] [torch.compile] use depyf to dump torch.compile internals (#10972) Signed-off-by: youkaichao --- requirements-common.txt | 1 + vllm/compilation/backends.py | 69 ++++++++++++++++++---------------- vllm/compilation/decorators.py | 2 +- vllm/compilation/monitor.py | 23 ++++++++++-- vllm/compilation/wrapper.py | 4 +- vllm/config.py | 6 ++- vllm/worker/model_runner.py | 3 +- 7 files changed, 66 insertions(+), 42 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 792cd58e80669..850b8f4101701 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. compressed-tensors == 0.8.0 # required for compressed-tensors +depyf==0.18.0 # required for profiling and debugging torch.compile diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f002a8ff905b1..09a3daa731829 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -9,7 +9,7 @@ import torch import torch.fx as fx import vllm.envs as envs -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.utils import weak_ref_tensors @@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): """ def __init__(self, module: torch.fx.GraphModule, - compile_submod_names: List[str], - compilation_configs: CompilationConfig, graph_pool): + compile_submod_names: List[str], vllm_config: VllmConfig, + graph_pool): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names - self.compilation_configs = compilation_configs + self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool + self.vllm_config = vllm_config def run(self, *args): fake_args = [ @@ -182,15 +183,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): compiled_graph_for_general_shape = wrap_inductor( submod, args, - self.compilation_configs.inductor_compile_config, - self.compilation_configs, + self.compilation_config.inductor_compile_config, + self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None, - use_inductor=self.compilation_configs.use_inductor) + use_inductor=self.compilation_config.use_inductor) self.module.__dict__[target] = PiecewiseBackend( - submod, self.compilation_configs, self.graph_pool, index, + submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape) @@ -211,7 +212,8 @@ class VllmBackend: which handles the post-grad passes. """ - compilation_configs: CompilationConfig + vllm_config: VllmConfig + compilation_config: CompilationConfig graph_pool: Any _called: bool = False # the graph we compiled @@ -227,7 +229,7 @@ class VllmBackend: def __init__( self, - compilation_configs: CompilationConfig, + vllm_config: VllmConfig, ): global global_graph_pool if global_graph_pool is None: @@ -244,13 +246,14 @@ class VllmBackend: self.sym_tensor_indices = [] self.input_buffers = [] - self.compilation_configs = compilation_configs + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config # `torch.compile` is JIT compiled, so we don't need to # do anything here def configure_post_pass(self): - config = self.compilation_configs + config = self.compilation_config self.post_grad_pass_manager.configure(config.pass_config) # Post-grad custom passes are run using the post_grad_custom_post_pass @@ -271,7 +274,7 @@ class VllmBackend: from .monitor import torch_compile_start_time dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) - self.compilation_configs.compilation_time += dynamo_time + self.compilation_config.compilation_time += dynamo_time # we control the compilation process, each instance can only be # called once @@ -281,7 +284,7 @@ class VllmBackend: self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_configs.splitting_ops) + graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) @@ -298,13 +301,13 @@ class VllmBackend: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.compilation_configs, + self.vllm_config, self.graph_pool).run(*example_inputs) self._called = True - if not self.compilation_configs.use_cudagraph or \ - not self.compilation_configs.cudagraph_copy_inputs: + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: return self.split_gm # if we need to copy input buffers for cudagraph @@ -364,10 +367,9 @@ class ConcreteSizeEntry: class PiecewiseBackend: - def __init__(self, graph: fx.GraphModule, - compilation_configs: CompilationConfig, graph_pool: Any, - piecewise_compile_index: int, total_piecewise_compiles: int, - sym_shape_indices: List[int], + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: List[int], compiled_graph_for_general_shape: Callable): """ The backend for piecewise compilation. @@ -375,7 +377,7 @@ class PiecewiseBackend: We will compile `self.graph` once for the general shape, and then compile for different shapes specified in - `compilation_configs.compile_sizes`. + `compilation_config.compile_sizes`. Independently, we will capture cudagraph for different shapes. @@ -383,7 +385,8 @@ class PiecewiseBackend: compile it first, and then capture cudagraph. """ self.graph = graph - self.compilation_configs = compilation_configs + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles @@ -393,10 +396,10 @@ class PiecewiseBackend: piecewise_compile_index == total_piecewise_compiles - 1) self.compile_sizes: Set[int] = set( - self.compilation_configs.compile_sizes) + self.compilation_config.compile_sizes) self.capture_sizes: Set[int] = set( - self.compilation_configs.capture_sizes - ) if self.compilation_configs.use_cudagraph else set() + self.compilation_config.capture_sizes + ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -423,7 +426,7 @@ class PiecewiseBackend: self.first_run_finished = True # no specific sizes to compile if self.is_last_graph and not self.to_be_compiled_sizes: - end_monitoring_torch_compile(self.compilation_configs) + end_monitoring_torch_compile(self.vllm_config) return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] @@ -443,28 +446,28 @@ class PiecewiseBackend: entry.runnable = wrap_inductor( self.graph, args, - self.compilation_configs.inductor_compile_config, - self.compilation_configs, + self.compilation_config.inductor_compile_config, + self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape, - use_inductor=self.compilation_configs.use_inductor) + use_inductor=self.compilation_config.use_inductor) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: - end_monitoring_torch_compile(self.compilation_configs) + end_monitoring_torch_compile(self.vllm_config) if not entry.use_cudagraph: return entry.runnable(*args) if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa entry.num_finished_warmup += 1 if self.is_first_graph: logger.debug( "Warming up %s/%s for shape %s", entry.num_finished_warmup, - self.compilation_configs.cudagraph_num_of_warmups, + self.compilation_config.cudagraph_num_of_warmups, runtime_shape) return entry.runnable(*args) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 938430fe2a501..805a217ee6ca1 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -185,7 +185,7 @@ def _support_torch_compile( "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}.") # here, it is the starting point of the `torch.compile` process - start_monitoring_torch_compile(self.vllm_config.compilation_config) + start_monitoring_torch_compile(self.vllm_config) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 3348674b09af2..b97e40415b41b 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,19 +1,36 @@ +import os import time -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.logger import init_logger logger = init_logger(__name__) +context_manager = None torch_compile_start_time: float = 0.0 -def start_monitoring_torch_compile(compilation_config: CompilationConfig): +def start_monitoring_torch_compile(vllm_config: VllmConfig): global torch_compile_start_time torch_compile_start_time = time.time() + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE and \ + compilation_config.debug_dump_path: + import depyf + path = os.path.join(compilation_config.debug_dump_path, + f"rank_{vllm_config.parallel_config.rank}") + global context_manager + context_manager = depyf.prepare_debug(path) + context_manager.__enter__() -def end_monitoring_torch_compile(compilation_config: CompilationConfig): + +def end_monitoring_torch_compile(vllm_config: VllmConfig): + compilation_config: CompilationConfig = vllm_config.compilation_config if compilation_config.level == CompilationLevel.PIECEWISE: logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) + global context_manager + if context_manager is not None: + context_manager.__exit__(None, None, None) + context_manager = None diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index bc4d292fef402..c10241b483169 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -32,8 +32,8 @@ class TorchCompileWrapperWithCustomDispatcher: # default compilation settings # compiling the forward method - backend = get_current_vllm_config( - ).compilation_config.init_backend() + vllm_config = get_current_vllm_config() + backend = vllm_config.compilation_config.init_backend(vllm_config) compiled_callable = torch.compile( self.forward, diff --git a/vllm/config.py b/vllm/config.py index 322c8f8990a40..7f9be5a3a98bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel): - 1: dynamo as is. - 2: dynamo once. - 3: piecewise compilation. + - debug_dump_path: the path to dump the debug information. - backend: the backend for compilation. It needs to be a string. - "" (empty string): use the default backend. - "eager"/"openxla"/...: use the specified backend registered in PyTorch. @@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel): certain small batchsizes, where inductor is good at optimizing. """ # noqa level: int = 0 + debug_dump_path: str = "" backend: str = "" custom_ops: List[str] = Field(default_factory=list) splitting_ops: List[str] = Field(default_factory=lambda: [ @@ -2394,7 +2396,7 @@ class CompilationConfig(BaseModel): self.static_forward_context = {} self.compilation_time = 0.0 - def init_backend(self) -> Union[str, Callable]: + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -2413,7 +2415,7 @@ class CompilationConfig(BaseModel): # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend - return VllmBackend(self) + return VllmBackend(vllm_config) def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): """To complete the initialization of config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 551b84435fdc0..26fd486130ce6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1162,7 +1162,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): - backend = self.vllm_config.compilation_config.init_backend() + backend = self.vllm_config.compilation_config.init_backend( + self.vllm_config) self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,