[torch.compile] add logging for compilation time (#10941)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
youkaichao 2024-12-06 02:07:15 -08:00 committed by GitHub
parent db87eb6c67
commit b031a455a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 75 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import copy import copy
import dataclasses import dataclasses
import time
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch from unittest.mock import patch
@ -14,6 +15,7 @@ from vllm.utils import weak_ref_tensors
from .counter import compilation_counter from .counter import compilation_counter
from .inductor_pass import InductorPass from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
from .pass_manager import PostGradPassManager from .pass_manager import PostGradPassManager
logger = init_logger(__name__) logger = init_logger(__name__)
@ -22,20 +24,21 @@ logger = init_logger(__name__)
def wrap_inductor(graph, def wrap_inductor(graph,
example_inputs, example_inputs,
additional_inductor_config, additional_inductor_config,
do_logging=False, compilation_config: CompilationConfig,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None, runtime_shape: Optional[int] = None,
use_inductor: bool = True): use_inductor: bool = True):
if graph_index == 0:
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
if not use_inductor: if not use_inductor:
return graph return graph
compilation_counter.num_inductor_compilations += 1 compilation_counter.num_inductor_compilations += 1
if do_logging:
if runtime_shape is None:
logger.info("Compiling a graph for general shape")
else:
logger.info("Compiling a graph for shape %s", runtime_shape)
from torch._inductor import config from torch._inductor import config
current_config = config.shallow_copy_dict() current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
@ -52,7 +55,23 @@ def wrap_inductor(graph,
# inductor can inplace modify the graph, so we need to copy it # inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980 # see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph) graph = copy.deepcopy(graph)
return compile_fx(graph, example_inputs, config_patches=current_config) compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
runtime_shape, elapsed)
return compiled_graph
@dataclasses.dataclass @dataclasses.dataclass
@ -114,6 +133,8 @@ def split_graph(graph: fx.GraphModule,
# we share the global graph pool among all the backends # we share the global graph pool among all the backends
global_graph_pool = None global_graph_pool = None
compilation_start_time = 0.0
class PiecewiseCompileInterpreter(torch.fx.Interpreter): class PiecewiseCompileInterpreter(torch.fx.Interpreter):
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
@ -157,12 +178,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
sym_shape_indices = [ sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt) i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
] ]
global compilation_start_time
compiled_graph_for_general_shape = wrap_inductor( compiled_graph_for_general_shape = wrap_inductor(
submod, submod,
args, args,
self.compilation_configs.inductor_compile_config, self.compilation_configs.inductor_compile_config,
self.compilation_configs,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None, runtime_shape=None,
do_logging=index == 0,
use_inductor=self.compilation_configs.use_inductor) use_inductor=self.compilation_configs.use_inductor)
self.module.__dict__[target] = PiecewiseBackend( self.module.__dict__[target] = PiecewiseBackend(
@ -379,6 +403,8 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either # the entries for different shapes that we need to either
# compile or capture cudagraph # compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
self.capture_sizes)
for shape in self.compile_sizes.union(self.capture_sizes): for shape in self.compile_sizes.union(self.capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry( self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape, runtime_shape=shape,
@ -389,6 +415,9 @@ class PiecewiseBackend:
def __call__(self, *args) -> Any: def __call__(self, *args) -> Any:
if not self.first_run_finished: if not self.first_run_finished:
self.first_run_finished = True self.first_run_finished = True
# no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]] runtime_shape = args[self.sym_shape_indices[0]]
@ -403,15 +432,22 @@ class PiecewiseBackend:
if entry.need_to_compile and not entry.compiled: if entry.need_to_compile and not entry.compiled:
entry.compiled = True entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments # args are real arguments
entry.runnable = wrap_inductor( entry.runnable = wrap_inductor(
self.graph, self.graph,
args, args,
self.compilation_configs.inductor_compile_config, self.compilation_configs.inductor_compile_config,
self.compilation_configs,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape, runtime_shape=runtime_shape,
do_logging=self.is_first_graph,
use_inductor=self.compilation_configs.use_inductor) use_inductor=self.compilation_configs.use_inductor)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs)
if not entry.use_cudagraph: if not entry.use_cudagraph:
return entry.runnable(*args) return entry.runnable(*args)

View File

@ -11,6 +11,8 @@ from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo from vllm.utils import supports_dynamo
from .monitor import start_monitoring_torch_compile
logger = init_logger(__name__) logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=type[nn.Module])
@ -155,6 +157,9 @@ def _support_torch_compile(
TorchCompileWrapperWithCustomDispatcher.__init__( TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level) self, compilation_level=vllm_config.compilation_config.level)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
start_monitoring_torch_compile(vllm_config.compilation_config)
cls.__init__ = __init__ cls.__init__ = __init__
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):

View File

@ -0,0 +1,14 @@
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger
logger = init_logger(__name__)
def start_monitoring_torch_compile(compilation_config: CompilationConfig):
    """Hook invoked when torch.compile monitoring begins.

    Currently a no-op: ``compilation_config`` is accepted but not used,
    and the function always returns ``None``.
    """
    return None
def end_monitoring_torch_compile(compilation_config: CompilationConfig):
    """Log the total graph compilation time stored on the config.

    The message is emitted only when ``compilation_config.level`` equals
    ``CompilationLevel.PIECEWISE``; for any other level nothing happens.
    The reported value is read from ``compilation_config.compilation_time``.
    """
    is_piecewise = compilation_config.level == CompilationLevel.PIECEWISE
    if not is_piecewise:
        return
    total_seconds = compilation_config.compilation_time
    logger.info("graph compilation takes %.2f s in total", total_seconds)

View File

@ -2281,6 +2281,7 @@ class CompilationConfig(BaseModel):
# keep track of enabled and disabled custom ops # keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr
compilation_time: float = PrivateAttr
# Per-model forward context # Per-model forward context
# Mainly used to store attention cls # Mainly used to store attention cls
@ -2319,6 +2320,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter() self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter() self.disabled_custom_ops = Counter()
self.static_forward_context = {} self.static_forward_context = {}
self.compilation_time = 0.0
def init_backend(self) -> Union[str, Callable]: def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:

View File

@ -473,6 +473,7 @@ class LLMEngine:
The workers will determine the number of blocks in both the GPU cache The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache. and the swap CPU cache.
""" """
start = time.time()
num_gpu_blocks, num_cpu_blocks = ( num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks()) self.model_executor.determine_num_available_blocks())
@ -488,6 +489,9 @@ class LLMEngine:
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
@classmethod @classmethod
def _get_executor_cls(cls, def _get_executor_cls(cls,

View File

@ -67,6 +67,7 @@ class EngineCore:
def _initialize_kv_caches(self, def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]: cache_config: CacheConfig) -> Tuple[int, int]:
start = time.time()
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
) )
@ -80,6 +81,9 @@ class EngineCore:
num_cpu_blocks = 0 num_cpu_blocks = 0
self.model_executor.initialize_cache(num_gpu_blocks) self.model_executor.initialize_cache(num_gpu_blocks)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
def add_request(self, request: EngineCoreRequest): def add_request(self, request: EngineCoreRequest):