[torch.compile] consider relevant code in compilation cache (#11614)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-08 18:46:43 +08:00 committed by GitHub
parent cfd3219f58
commit f12141170a
4 changed files with 99 additions and 35 deletions


@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
+                  vllm_backend: "VllmBackend",
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
         # see https://github.com/pytorch/pytorch/issues/138980
         graph = copy.deepcopy(graph)

-        cache_data = compilation_config.inductor_hash_cache
+        cache_data = vllm_backend.inductor_hash_cache
         if (runtime_shape, graph_index) in cache_data:
             # we compiled this graph before
             # so we can directly lookup the compiled graph via hash
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
                     hash_str, example_inputs, True, False)
                 assert inductor_compiled_graph is not None, (
                     "Inductor cache lookup failed. Please remove"
-                    f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
+                    f"the cache file {cache_data.cache_file_path} and try again."  # noqa
                 )

         # Inductor calling convention (function signature):
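Note: the cache consulted above behaves like a small persistent mapping from (runtime_shape, graph_index) to an Inductor artifact hash, plus a save_to_file() step used later in this diff. A minimal sketch under that assumption (hypothetical TinyHashCache, not vLLM's actual InductorHashCache implementation):

    import ast
    import os
    from typing import Dict, Optional, Tuple

    class TinyHashCache:
        """Toy on-disk cache keyed by (runtime_shape, graph_index)."""

        def __init__(self, cache_dir: str, disabled: bool = False):
            self.disabled = disabled
            self.cache_file_path = os.path.join(cache_dir, "hash_cache.txt")
            self.cache: Dict[Tuple[Optional[int], int], str] = {}
            if not disabled and os.path.exists(self.cache_file_path):
                with open(self.cache_file_path) as f:
                    # the file holds a literal dict, e.g. {(None, 0): "abc123"}
                    self.cache = ast.literal_eval(f.read())

        def __contains__(self, key) -> bool:
            return not self.disabled and key in self.cache

        def __getitem__(self, key) -> str:
            return self.cache[key]

        def __setitem__(self, key, value: str) -> None:
            self.cache[key] = value

        def save_to_file(self) -> None:
            if not self.disabled:
                with open(self.cache_file_path, "w") as f:
                    f.write(repr(self.cache))

On a cache hit, wrap_inductor can look up the stored hash and reload the previously compiled artifact instead of recompiling.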
@@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: List[str], vllm_config: VllmConfig,
-                 graph_pool):
+                 graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
@@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.vllm_config = vllm_config
+        self.vllm_backend = vllm_backend

     def run(self, *args):
         fake_args = [
@@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
-                compiled_graph_for_general_shape)
+                compiled_graph_for_general_shape, self.vllm_backend)
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
@@ -430,6 +433,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
+    inductor_hash_cache: InductorHashCache

     def __init__(
         self,
@@ -472,6 +476,53 @@ class VllmBackend:
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

+        if not self.compilation_config.cache_dir:
+            # no provided cache dir, generate one based on the known factors
+            # that affects the compilation. if none of the factors change,
+            # the cache dir will be the same so that we can reuse the compiled
+            # graph.
+
+            # 1. factors come from the vllm_config (it mainly summarizes how the
+            # model is created)
+            vllm_config = self.vllm_config
+            config_hash = vllm_config.compute_hash()
+
+            # 2. factors come from the code files that are traced by Dynamo (
+            # it mainly summarizes how the model is used in forward pass)
+            forward_code_files = list(
+                sorted(self.compilation_config.traced_files))
+            self.compilation_config.traced_files.clear()
+            logger.debug(
+                "Traced files (to be considered for compilation cache):\n%s",
+                "\n".join(forward_code_files))
+            hash_content = []
+            for filepath in forward_code_files:
+                hash_content.append(filepath)
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            import hashlib
+            code_hash = hashlib.md5(
+                "\n".join(hash_content).encode()).hexdigest()
+
+            # combine the two hashes to generate the cache dir
+            hash_key = hashlib.md5(
+                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
+                f"rank_{vllm_config.parallel_config.rank}")
+        else:
+            cache_dir = self.compilation_config.cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
+        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
+            cache_dir, disabled=disabled)
+        if disabled:
+            logger.info("vLLM's torch.compile cache is disabled.")
+        else:
+            logger.info("Using cache directory: %s for vLLM's torch.compile",
+                        cache_dir)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
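Note: because traced_files is only populated while Dynamo traces the forward pass, the cache directory is now derived here in VllmBackend.__call__ (after tracing) rather than when the config is created. The key combines the config hash with a hash of every traced source file, so editing model code changes the key and stale compiled graphs are not reused. A standalone sketch of that scheme (compute_cache_key is a hypothetical helper, not part of the diff):

    import hashlib
    from typing import Iterable

    def compute_cache_key(config_hash: str, traced_files: Iterable[str]) -> str:
        # hash the path and contents of every traced source file
        hash_content = []
        for filepath in sorted(traced_files):
            hash_content.append(filepath)
            with open(filepath) as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()
        # combine with the config hash; any change to either part changes the key
        return hashlib.md5(f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]

    if __name__ == "__main__":
        # e.g. key a cache directory by this module's own source file
        print(compute_cache_key("dummy-config-hash", [__file__]))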
@@ -507,8 +558,8 @@ class VllmBackend:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config,
-                                    self.graph_pool).run(*example_inputs)
+                                    self.vllm_config, self.graph_pool,
+                                    self).run(*example_inputs)

         self._called = True
@@ -577,7 +628,8 @@ class PiecewiseBackend:
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
                  total_piecewise_compiles: int, sym_shape_indices: List[int],
-                 compiled_graph_for_general_shape: Callable):
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
@@ -597,6 +649,7 @@ class PiecewiseBackend:
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
+        self.vllm_backend = vllm_backend

         self.is_first_graph = piecewise_compile_index == 0
         self.is_last_graph = (
@@ -634,7 +687,7 @@ class PiecewiseBackend:
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
+            self.vllm_backend.inductor_hash_cache.save_to_file()
         end_monitoring_torch_compile(self.vllm_config)

     def __call__(self, *args) -> Any:
@@ -662,6 +715,7 @@ class PiecewiseBackend:
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,


@@ -1,8 +1,10 @@
 import inspect
 from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from unittest.mock import patch

 import torch
 import torch.nn as nn
+from torch._dynamo.symbolic_convert import InliningInstructionTranslator

 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@@ -196,7 +198,31 @@ def _support_torch_compile(
             # we need to control all the compilation of the model.
             torch._dynamo.eval_frame.remove_from_cache(
                 self.original_code_object)
-            return self.compiled_callable(*args, **kwargs)
+
+            # collect all relevant files traced by Dynamo,
+            # so that the compilation cache can trigger re-compilation
+            # properly when any of these files change.
+
+            # 1. the file containing the top-level forward function
+            self.vllm_config.compilation_config.traced_files.add(
+                self.original_code_object.co_filename)
+
+            # 2. every time Dynamo sees a function call, it will inline
+            # the function by calling InliningInstructionTranslator.inline_call
+            # we hijack this function to know all the functions called
+            # during Dynamo tracing, and their corresponding files
+            inline_call = InliningInstructionTranslator.inline_call
+
+            def patched_inline_call(parent, func, args, kwargs):
+                code = func.get_code()
+                self.vllm_config.compilation_config.traced_files.add(
+                    code.co_filename)
+                return inline_call(parent, func, args, kwargs)
+
+            with patch.object(InliningInstructionTranslator, 'inline_call',
+                              patched_inline_call):
+                output = self.compiled_callable(*args, **kwargs)
+            return output

         # usually, capturing the model once is enough, and then we can
         # dispatch to the compiled code directly, without going through
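Note: the patched_inline_call wrapper above records the co_filename of every function Dynamo inlines, and the patch is scoped to the with-block, so tracing outside it is unaffected. The same observe-by-patching pattern on a toy class (illustrative only; Tracer stands in for Dynamo's InliningInstructionTranslator):

    from unittest.mock import patch

    class Tracer:
        def trace(self, fn):
            return fn()

    seen = set()
    original_trace = Tracer.trace

    def patched_trace(self, fn):
        # record the source file of every function routed through trace()
        seen.add(fn.__code__.co_filename)
        return original_trace(self, fn)

    with patch.object(Tracer, "trace", patched_trace):
        Tracer().trace(lambda: 42)   # recorded while the patch is active
    print(seen)                      # contains this file's path
    Tracer().trace(lambda: 0)        # outside the with-block: not recorded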


@@ -3,7 +3,6 @@ import copy
 import enum
 import hashlib
 import json
-import os
 import sys
 import warnings
 from contextlib import contextmanager
@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr
+    traced_files: Set[str] = PrivateAttr
     compilation_time: float = PrivateAttr
-    # should be InductorHashCache, but Pydantic does not support it
-    inductor_hash_cache: Any = PrivateAttr

     # Per-model forward context
     # Mainly used to store attention cls
@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
             "compilation_time",
             "bs_to_padded_graph_size",
             "pass_config",
+            "traced_files",
         }
         return self.model_dump_json(exclude=exclude, exclude_unset=True)
@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
+        self.traced_files = set()
         self.static_forward_context = {}
         self.compilation_time = 0.0
@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
         # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE

-        if not self.cache_dir:
-            # no provided cache dir, generate one based on the known factors
-            # that affects the compilation. if none of the factors change,
-            # the cache dir will be the same so that we can reuse the compiled
-            # graph.
-            hash_key = vllm_config.compute_hash()
-            cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-            os.makedirs(cache_dir, exist_ok=True)
-            self.cache_dir = cache_dir
-
-        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
-        from vllm.compilation.backends import InductorHashCache
-        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            self.cache_dir, disabled=disabled)
-        if disabled:
-            logger.info("vLLM's torch.compile cache is disabled.")
-        else:
-            logger.info(
-                "Using cache directory: %s for vLLM's torch.compile",
-                self.cache_dir)
-
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)


@@ -1108,6 +1108,13 @@ class IntermediateTensors:

     tensors: Dict[str, torch.Tensor]

+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
     def __getitem__(self, key: Union[str, slice]):
         if isinstance(key, str):
             return self.tensors[key]
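Note: the hand-written __init__ matters because a dataclass-generated __init__ is created by exec()ing a generated string, so its code object does not point at a real source file and the traced-files collection above would record an unreadable pseudo-filename instead of the file that actually defines the class. A quick illustrative check (not part of the diff):

    import dataclasses

    @dataclasses.dataclass
    class Generated:
        tensors: dict

    class Manual:
        def __init__(self, tensors):
            self.tensors = tensors

    print(Generated.__init__.__code__.co_filename)  # e.g. "<string>"
    print(Manual.__init__.__code__.co_filename)     # this file's real path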