[torch.compile] add logging for compilation time (#10941)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
youkaichao 2024-12-06 02:07:15 -08:00 committed by GitHub
parent db87eb6c67
commit b031a455a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 75 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import copy
import dataclasses
import time
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@ -14,6 +15,7 @@ from vllm.utils import weak_ref_tensors
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
@ -22,20 +24,21 @@ logger = init_logger(__name__)
def wrap_inductor(graph,
example_inputs,
additional_inductor_config,
do_logging=False,
compilation_config: CompilationConfig,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
use_inductor: bool = True):
if graph_index == 0:
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
if not use_inductor:
return graph
compilation_counter.num_inductor_compilations += 1
if do_logging:
if runtime_shape is None:
logger.info("Compiling a graph for general shape")
else:
logger.info("Compiling a graph for shape %s", runtime_shape)
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx
@ -52,7 +55,23 @@ def wrap_inductor(graph,
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
return compile_fx(graph, example_inputs, config_patches=current_config)
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
runtime_shape, elapsed)
return compiled_graph
@dataclasses.dataclass
@ -114,6 +133,8 @@ def split_graph(graph: fx.GraphModule,
# we share the global graph pool among all the backends
global_graph_pool = None
compilation_start_time = 0.0
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
@ -157,12 +178,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_general_shape = wrap_inductor(
submod,
args,
self.compilation_configs.inductor_compile_config,
self.compilation_configs,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
do_logging=index == 0,
use_inductor=self.compilation_configs.use_inductor)
self.module.__dict__[target] = PiecewiseBackend(
@ -379,6 +403,8 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
self.capture_sizes)
for shape in self.compile_sizes.union(self.capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
@ -389,6 +415,9 @@ class PiecewiseBackend:
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
# no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
@ -403,15 +432,22 @@ class PiecewiseBackend:
if entry.need_to_compile and not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = wrap_inductor(
self.graph,
args,
self.compilation_configs.inductor_compile_config,
self.compilation_configs,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
do_logging=self.is_first_graph,
use_inductor=self.compilation_configs.use_inductor)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
if not entry.use_cudagraph:
return entry.runnable(*args)

View File

@ -11,6 +11,8 @@ from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
from .monitor import start_monitoring_torch_compile
logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module])
@ -155,6 +157,9 @@ def _support_torch_compile(
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
start_monitoring_torch_compile(vllm_config.compilation_config)
cls.__init__ = __init__
def __call__(self, *args, **kwargs):

View File

@ -0,0 +1,14 @@
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger
logger = init_logger(__name__)
def start_monitoring_torch_compile(compilation_config: CompilationConfig):
pass
def end_monitoring_torch_compile(compilation_config: CompilationConfig):
if compilation_config.level == CompilationLevel.PIECEWISE:
logger.info("graph compilation takes %.2f s in total",
compilation_config.compilation_time)

View File

@ -2281,6 +2281,7 @@ class CompilationConfig(BaseModel):
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
compilation_time: float = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
@ -2319,6 +2320,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.static_forward_context = {}
self.compilation_time = 0.0
def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:

View File

@ -473,6 +473,7 @@ class LLMEngine:
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
start = time.time()
num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks())
@ -488,6 +489,9 @@ class LLMEngine:
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
@classmethod
def _get_executor_cls(cls,

View File

@ -67,6 +67,7 @@ class EngineCore:
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
start = time.time()
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
)
@ -80,6 +81,9 @@ class EngineCore:
num_cpu_blocks = 0
self.model_executor.initialize_cache(num_gpu_blocks)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks
def add_request(self, request: EngineCoreRequest):