[torch.compile] use depyf to dump torch.compile internals (#10972)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent fd22220687
commit 91642db952
@@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
 einops # Required for Qwen2-VL.
 compressed-tensors == 0.8.0 # required for compressed-tensors
+depyf==0.18.0 # required for profiling and debugging torch.compile
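For orientation: depyf's documented entry point is the `prepare_debug` context manager, which is what this commit wires into vLLM further down. A minimal standalone sketch of the tool outside vLLM (the dump directory name is illustrative):

import torch
import depyf

def toy(x):
    return torch.relu(x) + 1

compiled_toy = torch.compile(toy)

# Anything compiled while the context is open is dumped (artifacts such as the
# transformed bytecode and compiled-graph source) under ./depyf_dump.
with depyf.prepare_debug("./depyf_dump"):
    compiled_toy(torch.randn(8))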
@@ -9,7 +9,7 @@ import torch
 import torch.fx as fx
 
 import vllm.envs as envs
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
@@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str],
-                 compilation_configs: CompilationConfig, graph_pool):
+                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 graph_pool):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
-        self.compilation_configs = compilation_configs
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
+        self.vllm_config = vllm_config
 
     def run(self, *args):
         fake_args = [
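The same substitution repeats through the rest of this file (VllmBackend and PiecewiseBackend below): constructors now take the whole VllmConfig and keep both the full config and a handle to its compilation_config. A schematic sketch of that pattern, with an illustrative class name:

from vllm.config import VllmConfig

class SomeCompilationComponent:
    """Illustrative only: the constructor pattern this commit applies to
    PiecewiseCompileInterpreter, VllmBackend and PiecewiseBackend."""

    def __init__(self, vllm_config: VllmConfig):
        # keep the full config for anything beyond compilation settings
        # (e.g. parallel_config.rank, used by monitor.py further down)
        self.vllm_config = vllm_config
        # and a short alias for the sub-config used on the hot paths
        self.compilation_config = vllm_config.compilation_config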
@@ -182,15 +183,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             compiled_graph_for_general_shape = wrap_inductor(
                 submod,
                 args,
-                self.compilation_configs.inductor_compile_config,
-                self.compilation_configs,
+                self.compilation_config.inductor_compile_config,
+                self.compilation_config,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
-                use_inductor=self.compilation_configs.use_inductor)
+                use_inductor=self.compilation_config.use_inductor)
 
             self.module.__dict__[target] = PiecewiseBackend(
-                submod, self.compilation_configs, self.graph_pool, index,
+                submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape)
 
@@ -211,7 +212,8 @@ class VllmBackend:
     which handles the post-grad passes.
     """
 
-    compilation_configs: CompilationConfig
+    vllm_config: VllmConfig
+    compilation_config: CompilationConfig
     graph_pool: Any
     _called: bool = False
     # the graph we compiled
@@ -227,7 +229,7 @@ class VllmBackend:
 
     def __init__(
         self,
-        compilation_configs: CompilationConfig,
+        vllm_config: VllmConfig,
     ):
         global global_graph_pool
         if global_graph_pool is None:
@@ -244,13 +246,14 @@ class VllmBackend:
         self.sym_tensor_indices = []
         self.input_buffers = []
 
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
     def configure_post_pass(self):
-        config = self.compilation_configs
+        config = self.compilation_config
         self.post_grad_pass_manager.configure(config.pass_config)
 
         # Post-grad custom passes are run using the post_grad_custom_post_pass
@@ -271,7 +274,7 @@ class VllmBackend:
         from .monitor import torch_compile_start_time
         dynamo_time = time.time() - torch_compile_start_time
         logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
-        self.compilation_configs.compilation_time += dynamo_time
+        self.compilation_config.compilation_time += dynamo_time
 
         # we control the compilation process, each instance can only be
         # called once
@@ -281,7 +284,7 @@ class VllmBackend:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.splitting_ops)
+            graph, self.compilation_config.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
@@ -298,13 +301,13 @@ class VllmBackend:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.compilation_configs,
+                                    self.vllm_config,
                                     self.graph_pool).run(*example_inputs)
 
         self._called = True
 
-        if not self.compilation_configs.use_cudagraph or \
-            not self.compilation_configs.cudagraph_copy_inputs:
+        if not self.compilation_config.use_cudagraph or \
+            not self.compilation_config.cudagraph_copy_inputs:
             return self.split_gm
 
         # if we need to copy input buffers for cudagraph
@@ -364,10 +367,9 @@ class ConcreteSizeEntry:
 
 class PiecewiseBackend:
 
-    def __init__(self, graph: fx.GraphModule,
-                 compilation_configs: CompilationConfig, graph_pool: Any,
-                 piecewise_compile_index: int, total_piecewise_compiles: int,
-                 sym_shape_indices: List[int],
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, piecewise_compile_index: int,
+                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                  compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
@@ -375,7 +377,7 @@ class PiecewiseBackend:
 
         We will compile `self.graph` once for the general shape,
         and then compile for different shapes specified in
-        `compilation_configs.compile_sizes`.
+        `compilation_config.compile_sizes`.
 
         Independently, we will capture cudagraph for different shapes.
 
@@ -383,7 +385,8 @@ class PiecewiseBackend:
         compile it first, and then capture cudagraph.
         """
         self.graph = graph
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
@@ -393,10 +396,10 @@ class PiecewiseBackend:
             piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.compile_sizes: Set[int] = set(
-            self.compilation_configs.compile_sizes)
+            self.compilation_config.compile_sizes)
         self.capture_sizes: Set[int] = set(
-            self.compilation_configs.capture_sizes
-        ) if self.compilation_configs.use_cudagraph else set()
+            self.compilation_config.capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
 
         self.first_run_finished = False
 
@@ -423,7 +426,7 @@ class PiecewiseBackend:
             self.first_run_finished = True
             # no specific sizes to compile
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -443,28 +446,28 @@ class PiecewiseBackend:
             entry.runnable = wrap_inductor(
                 self.graph,
                 args,
-                self.compilation_configs.inductor_compile_config,
-                self.compilation_configs,
+                self.compilation_config.inductor_compile_config,
+                self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
-                use_inductor=self.compilation_configs.use_inductor)
+                use_inductor=self.compilation_config.use_inductor)
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
                         "Warming up %s/%s for shape %s",
                         entry.num_finished_warmup,
-                        self.compilation_configs.cudagraph_num_of_warmups,
+                        self.compilation_config.cudagraph_num_of_warmups,
                         runtime_shape)
                 return entry.runnable(*args)
 
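For readers new to this file: apart from the config rename, the surrounding `__call__` logic is unchanged here and dispatches on the runtime batch size, compiling lazily per size. A stripped-down, vLLM-independent sketch of that dispatch structure (no cudagraph handling, names illustrative):

from typing import Callable, Dict, Set

class ShapeSpecializingRunner:
    """Illustrative only: mirrors PiecewiseBackend's dispatch, not its API."""

    def __init__(self, fn: Callable, compile_sizes: Set[int]):
        self.fn = fn                    # stands in for the FX graph
        self.general = fn               # stands in for the general-shape compile
        self.compile_sizes = compile_sizes
        self.specialized: Dict[int, Callable] = {}

    def __call__(self, batch):
        size = len(batch)
        if size not in self.compile_sizes:
            # shapes we never asked to specialize run the general compilation
            return self.general(batch)
        if size not in self.specialized:
            # stands in for wrap_inductor(..., runtime_shape=size)
            self.specialized[size] = self.fn
        return self.specialized[size](batch)

runner = ShapeSpecializingRunner(lambda b: [x * 2 for x in b], {4, 8})
print(runner([1, 2, 3]), runner([1, 2, 3, 4]))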
@@ -185,7 +185,7 @@ def _support_torch_compile(
                         "Unsupported dynamic dimensions"
                         f" {dims} for argument {k} with type {type(arg)}.")
             # here, it is the starting point of the `torch.compile` process
-            start_monitoring_torch_compile(self.vllm_config.compilation_config)
+            start_monitoring_torch_compile(self.vllm_config)
 
             # if we don't use custom dispatcher, we can directly call the
             # compiled function and let torch.compile handle the dispatching,
@@ -1,19 +1,36 @@
+import os
 import time
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+context_manager = None
 torch_compile_start_time: float = 0.0
 
 
-def start_monitoring_torch_compile(compilation_config: CompilationConfig):
+def start_monitoring_torch_compile(vllm_config: VllmConfig):
     global torch_compile_start_time
     torch_compile_start_time = time.time()
 
+    compilation_config: CompilationConfig = vllm_config.compilation_config
+    if compilation_config.level == CompilationLevel.PIECEWISE and \
+        compilation_config.debug_dump_path:
+        import depyf
+        path = os.path.join(compilation_config.debug_dump_path,
+                            f"rank_{vllm_config.parallel_config.rank}")
+        global context_manager
+        context_manager = depyf.prepare_debug(path)
+        context_manager.__enter__()
+
 
-def end_monitoring_torch_compile(compilation_config: CompilationConfig):
+def end_monitoring_torch_compile(vllm_config: VllmConfig):
+    compilation_config: CompilationConfig = vllm_config.compilation_config
     if compilation_config.level == CompilationLevel.PIECEWISE:
         logger.info("torch.compile takes %.2f s in total",
                     compilation_config.compilation_time)
+        global context_manager
+        if context_manager is not None:
+            context_manager.__exit__(None, None, None)
+            context_manager = None
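Functionally, the two hooks now bracket the whole compilation with a per-rank depyf debug context. A hand-rolled sketch of what that amounts to (path and rank values are illustrative):

import os
import depyf

debug_dump_path = "/tmp/vllm_compile_dump"  # illustrative
rank = 0                                    # illustrative parallel rank

path = os.path.join(debug_dump_path, f"rank_{rank}")
context_manager = depyf.prepare_debug(path)
context_manager.__enter__()      # what start_monitoring_torch_compile now does
try:
    pass  # Dynamo/Inductor compilation would run here
finally:
    # what end_monitoring_torch_compile now does once all shapes are compiled
    context_manager.__exit__(None, None, None)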
@@ -32,8 +32,8 @@ class TorchCompileWrapperWithCustomDispatcher:
             # default compilation settings
             # compiling the forward method
 
-            backend = get_current_vllm_config(
-            ).compilation_config.init_backend()
+            vllm_config = get_current_vllm_config()
+            backend = vllm_config.compilation_config.init_backend(vllm_config)
 
             compiled_callable = torch.compile(
                 self.forward,
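The wrapper follows the new init_backend contract: resolve the ambient VllmConfig once and pass it through. A hedged sketch of that call shape, assuming get_current_vllm_config is importable from vllm.config (as the wrapper does) and that a current VllmConfig has already been set by the caller:

import torch
from vllm.config import get_current_vllm_config

def build_compiled_forward(forward):
    # illustrative helper, not part of this commit
    vllm_config = get_current_vllm_config()
    backend = vllm_config.compilation_config.init_backend(vllm_config)
    return torch.compile(forward, backend=backend)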
@@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel):
             - 1: dynamo as is.
             - 2: dynamo once.
             - 3: piecewise compilation.
+        - debug_dump_path: the path to dump the debug information.
         - backend: the backend for compilation. It needs to be a string.
             - "" (empty string): use the default backend.
             - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel):
     certain small batchsizes, where inductor is good at optimizing.
     """ # noqa
     level: int = 0
+    debug_dump_path: str = ""
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
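With the new field in place, enabling the dump means setting both knobs on the config; a minimal sketch, assuming the remaining fields keep their defaults:

from vllm.config import CompilationConfig

# level 3 is piecewise compilation (see the docstring above); the path is
# illustrative — depyf artifacts land under <path>/rank_<parallel rank>
config = CompilationConfig(level=3, debug_dump_path="/tmp/vllm_compile_dump")
print(config.level, config.debug_dump_path)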
@@ -2394,7 +2396,7 @@ class CompilationConfig(BaseModel):
         self.static_forward_context = {}
         self.compilation_time = 0.0
 
-    def init_backend(self) -> Union[str, Callable]:
+    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
@@ -2413,7 +2415,7 @@ class CompilationConfig(BaseModel):
         # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
         from vllm.compilation.backends import VllmBackend
-        return VllmBackend(self)
+        return VllmBackend(vllm_config)
 
     def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
         """To complete the initialization of config,
@@ -1162,7 +1162,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         if self.vllm_config.compilation_config.level ==\
             CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
-            backend = self.vllm_config.compilation_config.init_backend()
+            backend = self.vllm_config.compilation_config.init_backend(
+                self.vllm_config)
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,