mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
cfd3219f58
commit
f12141170a
@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
vllm_backend: "VllmBackend",
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None,
|
||||
@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
cache_data = compilation_config.inductor_hash_cache
|
||||
cache_data = vllm_backend.inductor_hash_cache
|
||||
if (runtime_shape, graph_index) in cache_data:
|
||||
# we compiled this graph before
|
||||
# so we can directly lookup the compiled graph via hash
|
||||
@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
|
||||
hash_str, example_inputs, True, False)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa
|
||||
f"the cache file {cache_data.cache_file_path} and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
|
||||
def __init__(self, module: torch.fx.GraphModule,
|
||||
compile_submod_names: List[str], vllm_config: VllmConfig,
|
||||
graph_pool):
|
||||
graph_pool, vllm_backend: "VllmBackend"):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
self.fake_mode = detect_fake_mode()
|
||||
@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.graph_pool = graph_pool
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
def run(self, *args):
|
||||
fake_args = [
|
||||
@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
self.vllm_backend,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
self.module.__dict__[target] = PiecewiseBackend(
|
||||
submod, self.vllm_config, self.graph_pool, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_general_shape)
|
||||
compiled_graph_for_general_shape, self.vllm_backend)
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
@ -430,6 +433,7 @@ class VllmBackend:
|
||||
post_grad_passes: Sequence[Callable]
|
||||
sym_tensor_indices: List[int]
|
||||
input_buffers: List[torch.Tensor]
|
||||
inductor_hash_cache: InductorHashCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -472,6 +476,53 @@ class VllmBackend:
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||
|
||||
if not self.compilation_config.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
# that affects the compilation. if none of the factors change,
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
|
||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||
# model is created)
|
||||
vllm_config = self.vllm_config
|
||||
config_hash = vllm_config.compute_hash()
|
||||
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
forward_code_files = list(
|
||||
sorted(self.compilation_config.traced_files))
|
||||
self.compilation_config.traced_files.clear()
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s",
|
||||
"\n".join(forward_code_files))
|
||||
hash_content = []
|
||||
for filepath in forward_code_files:
|
||||
hash_content.append(filepath)
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
import hashlib
|
||||
code_hash = hashlib.md5(
|
||||
"\n".join(hash_content).encode()).hexdigest()
|
||||
|
||||
# combine the two hashes to generate the cache dir
|
||||
hash_key = hashlib.md5(
|
||||
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
|
||||
f"rank_{vllm_config.parallel_config.rank}")
|
||||
else:
|
||||
cache_dir = self.compilation_config.cache_dir
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
|
||||
cache_dir, disabled=disabled)
|
||||
if disabled:
|
||||
logger.info("vLLM's torch.compile cache is disabled.")
|
||||
else:
|
||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||
cache_dir)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
@ -507,8 +558,8 @@ class VllmBackend:
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
||||
self.vllm_config,
|
||||
self.graph_pool).run(*example_inputs)
|
||||
self.vllm_config, self.graph_pool,
|
||||
self).run(*example_inputs)
|
||||
|
||||
self._called = True
|
||||
|
||||
@ -577,7 +628,8 @@ class PiecewiseBackend:
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: List[int],
|
||||
compiled_graph_for_general_shape: Callable):
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation and cudagraph capturing.
|
||||
@ -597,6 +649,7 @@ class PiecewiseBackend:
|
||||
self.graph_pool = graph_pool
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = (
|
||||
@ -634,7 +687,7 @@ class PiecewiseBackend:
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.compilation_config.inductor_hash_cache.save_to_file()
|
||||
self.vllm_backend.inductor_hash_cache.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
@ -662,6 +715,7 @@ class PiecewiseBackend:
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
self.vllm_backend,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
@ -196,7 +198,31 @@ def _support_torch_compile(
|
||||
# we need to control all the compilation of the model.
|
||||
torch._dynamo.eval_frame.remove_from_cache(
|
||||
self.original_code_object)
|
||||
return self.compiled_callable(*args, **kwargs)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
self.original_code_object.co_filename)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call
|
||||
|
||||
def patched_inline_call(parent, func, args, kwargs):
|
||||
code = func.get_code()
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
code.co_filename)
|
||||
return inline_call(parent, func, args, kwargs)
|
||||
|
||||
with patch.object(InliningInstructionTranslator, 'inline_call',
|
||||
patched_inline_call):
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
return output
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
# dispatch to the compiled code directly, without going through
|
||||
|
||||
@ -3,7 +3,6 @@ import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||
disabled_custom_ops: Counter[str] = PrivateAttr
|
||||
traced_files: Set[str] = PrivateAttr
|
||||
compilation_time: float = PrivateAttr
|
||||
# should be InductorHashCache, but Pydantic does not support it
|
||||
inductor_hash_cache: Any = PrivateAttr
|
||||
|
||||
# Per-model forward context
|
||||
# Mainly used to store attention cls
|
||||
@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
|
||||
"compilation_time",
|
||||
"bs_to_padded_graph_size",
|
||||
"pass_config",
|
||||
"traced_files",
|
||||
}
|
||||
return self.model_dump_json(exclude=exclude, exclude_unset=True)
|
||||
|
||||
@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
|
||||
|
||||
self.enabled_custom_ops = Counter()
|
||||
self.disabled_custom_ops = Counter()
|
||||
self.traced_files = set()
|
||||
self.static_forward_context = {}
|
||||
self.compilation_time = 0.0
|
||||
|
||||
@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
|
||||
# merge with the config use_inductor
|
||||
assert self.level == CompilationLevel.PIECEWISE
|
||||
|
||||
if not self.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
# that affects the compilation. if none of the factors change,
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
hash_key = vllm_config.compute_hash()
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
|
||||
f"rank_{vllm_config.parallel_config.rank}")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
from vllm.compilation.backends import InductorHashCache
|
||||
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
|
||||
self.cache_dir, disabled=disabled)
|
||||
if disabled:
|
||||
logger.info("vLLM's torch.compile cache is disabled.")
|
||||
else:
|
||||
logger.info(
|
||||
"Using cache directory: %s for vLLM's torch.compile",
|
||||
self.cache_dir)
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
return VllmBackend(vllm_config)
|
||||
|
||||
|
||||
@ -1108,6 +1108,13 @@ class IntermediateTensors:
|
||||
|
||||
tensors: Dict[str, torch.Tensor]
|
||||
|
||||
def __init__(self, tensors):
|
||||
# manually define this function, so that
|
||||
# Dynamo knows `IntermediateTensors()` comes from this file.
|
||||
# Otherwise, dataclass will generate this function by evaluating
|
||||
# a string, and we will lose the information about the source file.
|
||||
self.tensors = tensors
|
||||
|
||||
def __getitem__(self, key: Union[str, slice]):
|
||||
if isinstance(key, str):
|
||||
return self.tensors[key]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user