[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit f12141170a
parent cfd3219f58
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
+                  vllm_backend: "VllmBackend",
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,

@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)

-    cache_data = compilation_config.inductor_hash_cache
+    cache_data = vllm_backend.inductor_hash_cache
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash

@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
             hash_str, example_inputs, True, False)
         assert inductor_compiled_graph is not None, (
             "Inductor cache lookup failed. Please remove"
-            f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
+            f"the cache file {cache_data.cache_file_path} and try again."  # noqa
         )

         # Inductor calling convention (function signature):

@@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):

     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: List[str], vllm_config: VllmConfig,
-                 graph_pool):
+                 graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()

@@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.vllm_config = vllm_config
+        self.vllm_backend = vllm_backend

     def run(self, *args):
         fake_args = [

@@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,

@@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
-                compiled_graph_for_general_shape)
+                compiled_graph_for_general_shape, self.vllm_backend)

             compilation_counter.num_piecewise_capturable_graphs_seen += 1

@@ -430,6 +433,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
+    inductor_hash_cache: InductorHashCache

     def __init__(
         self,
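Note on the hunks above: the inductor hash cache now lives on VllmBackend and is passed into wrap_inductor as vllm_backend, instead of hanging off CompilationConfig. The snippet below is only an illustrative toy of the idea behind such a cache, not vLLM's InductorHashCache: a dict keyed by (runtime_shape, graph_index) mapping to a hash string, persisted under the cache directory. The class and file names here are invented for the sketch.

import ast
import os
from typing import Dict, Optional, Tuple


class ToyHashCache:
    """Toy stand-in for an on-disk (runtime_shape, graph_index) -> hash cache."""

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.disabled = disabled
        self.cache: Dict[Tuple[Optional[int], int], str] = {}
        # hypothetical file name; only the attribute name mirrors the diff
        self.cache_file_path = os.path.join(cache_dir, "toy_hash_cache.py")
        if not disabled and os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                # stored as a Python literal so it is easy to inspect by hand
                self.cache = ast.literal_eval(f.read())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        return not self.disabled and key in self.cache

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        return self.cache[key]

    def __setitem__(self, key: Tuple[Optional[int], int], hash_str: str):
        self.cache[key] = hash_str

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(repr(self.cache))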
@@ -472,6 +476,53 @@ class VllmBackend:

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

+        if not self.compilation_config.cache_dir:
+            # no provided cache dir, generate one based on the known factors
+            # that affects the compilation. if none of the factors change,
+            # the cache dir will be the same so that we can reuse the compiled
+            # graph.
+
+            # 1. factors come from the vllm_config (it mainly summarizes how the
+            # model is created)
+            vllm_config = self.vllm_config
+            config_hash = vllm_config.compute_hash()
+
+            # 2. factors come from the code files that are traced by Dynamo (
+            # it mainly summarizes how the model is used in forward pass)
+            forward_code_files = list(
+                sorted(self.compilation_config.traced_files))
+            self.compilation_config.traced_files.clear()
+            logger.debug(
+                "Traced files (to be considered for compilation cache):\n%s",
+                "\n".join(forward_code_files))
+            hash_content = []
+            for filepath in forward_code_files:
+                hash_content.append(filepath)
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            import hashlib
+            code_hash = hashlib.md5(
+                "\n".join(hash_content).encode()).hexdigest()
+
+            # combine the two hashes to generate the cache dir
+            hash_key = hashlib.md5(
+                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
+                f"rank_{vllm_config.parallel_config.rank}")
+        else:
+            cache_dir = self.compilation_config.cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
+        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
+            cache_dir, disabled=disabled)
+        if disabled:
+            logger.info("vLLM's torch.compile cache is disabled.")
+        else:
+            logger.info("Using cache directory: %s for vLLM's torch.compile",
+                        cache_dir)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
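For reference, the cache-directory scheme added above boils down to two MD5 digests folded into a short key: one over the config and one over the contents of every file Dynamo traced. A minimal standalone sketch follows; compute_toy_cache_dir and its arguments are made-up names, and the VLLM_CACHE_ROOT environment lookup is replaced by an explicit cache_root argument.

import hashlib
import os
from typing import Iterable


def compute_toy_cache_dir(config_hash: str, traced_files: Iterable[str],
                          cache_root: str, rank: int) -> str:
    # hash the path and content of every traced file
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()
    # fold the config hash and the code hash into one short directory key
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")


# editing any traced file changes code_hash and therefore the directory,
# so stale compiled artifacts are never picked up by accident
print(compute_toy_cache_dir("cfg123", [__file__], "/tmp/toy_cache", rank=0))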
@@ -507,8 +558,8 @@ class VllmBackend:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config,
-                                    self.graph_pool).run(*example_inputs)
+                                    self.vllm_config, self.graph_pool,
+                                    self).run(*example_inputs)

         self._called = True

@@ -577,7 +628,8 @@ class PiecewiseBackend:
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
                  total_piecewise_compiles: int, sym_shape_indices: List[int],
-                 compiled_graph_for_general_shape: Callable):
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.

@@ -597,6 +649,7 @@ class PiecewiseBackend:
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
+        self.vllm_backend = vllm_backend

         self.is_first_graph = piecewise_compile_index == 0
         self.is_last_graph = (

@@ -634,7 +687,7 @@ class PiecewiseBackend:
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
+            self.vllm_backend.inductor_hash_cache.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)

     def __call__(self, *args) -> Any:

@@ -662,6 +715,7 @@ class PiecewiseBackend:
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
@@ -1,8 +1,10 @@
 import inspect
 from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from unittest.mock import patch

 import torch
 import torch.nn as nn
+from torch._dynamo.symbolic_convert import InliningInstructionTranslator

 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

@@ -196,7 +198,31 @@ def _support_torch_compile(
         # we need to control all the compilation of the model.
         torch._dynamo.eval_frame.remove_from_cache(
             self.original_code_object)
-        return self.compiled_callable(*args, **kwargs)
+
+        # collect all relevant files traced by Dynamo,
+        # so that the compilation cache can trigger re-compilation
+        # properly when any of these files change.
+
+        # 1. the file containing the top-level forward function
+        self.vllm_config.compilation_config.traced_files.add(
+            self.original_code_object.co_filename)
+
+        # 2. every time Dynamo sees a function call, it will inline
+        # the function by calling InliningInstructionTranslator.inline_call
+        # we hijack this function to know all the functions called
+        # during Dynamo tracing, and their corresponding files
+        inline_call = InliningInstructionTranslator.inline_call
+
+        def patched_inline_call(parent, func, args, kwargs):
+            code = func.get_code()
+            self.vllm_config.compilation_config.traced_files.add(
+                code.co_filename)
+            return inline_call(parent, func, args, kwargs)
+
+        with patch.object(InliningInstructionTranslator, 'inline_call',
+                          patched_inline_call):
+            output = self.compiled_callable(*args, **kwargs)
+        return output

     # usually, capturing the model once is enough, and then we can
     # dispatch to the compiled code directly, without going through
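The hijack of InliningInstructionTranslator.inline_call above is a standard unittest.mock.patch.object pattern: temporarily swap in a wrapper that records which file defines each inlined function, then forward to the original. Stripped of Dynamo, the same pattern looks like the sketch below; Tracer and the recorded set are invented for illustration.

from unittest.mock import patch


class Tracer:
    def inline_call(self, func):
        return func()


traced_files = set()
original_inline_call = Tracer.inline_call


def patched_inline_call(self, func):
    # record which file defines the function being inlined,
    # then defer to the original implementation unchanged
    traced_files.add(func.__code__.co_filename)
    return original_inline_call(self, func)


with patch.object(Tracer, "inline_call", patched_inline_call):
    Tracer().inline_call(lambda: None)

print(traced_files)  # contains the path of this file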
@@ -3,7 +3,6 @@ import copy
 import enum
 import hashlib
 import json
-import os
 import sys
 import warnings
 from contextlib import contextmanager

@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr
+    traced_files: Set[str] = PrivateAttr
     compilation_time: float = PrivateAttr
-    # should be InductorHashCache, but Pydantic does not support it
-    inductor_hash_cache: Any = PrivateAttr

     # Per-model forward context
     # Mainly used to store attention cls

@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
             "compilation_time",
             "bs_to_padded_graph_size",
             "pass_config",
+            "traced_files",
         }
         return self.model_dump_json(exclude=exclude, exclude_unset=True)

@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):

         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
+        self.traced_files = set()
         self.static_forward_context = {}
         self.compilation_time = 0.0

@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
         # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE

-        if not self.cache_dir:
-            # no provided cache dir, generate one based on the known factors
-            # that affects the compilation. if none of the factors change,
-            # the cache dir will be the same so that we can reuse the compiled
-            # graph.
-            hash_key = vllm_config.compute_hash()
-            cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-            os.makedirs(cache_dir, exist_ok=True)
-            self.cache_dir = cache_dir
-
-        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
-        from vllm.compilation.backends import InductorHashCache
-        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            self.cache_dir, disabled=disabled)
-        if disabled:
-            logger.info("vLLM's torch.compile cache is disabled.")
-        else:
-            logger.info(
-                "Using cache directory: %s for vLLM's torch.compile",
-                self.cache_dir)
-
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)

|
|||||||
|
|
||||||
tensors: Dict[str, torch.Tensor]
|
tensors: Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
# manually define this function, so that
|
||||||
|
# Dynamo knows `IntermediateTensors()` comes from this file.
|
||||||
|
# Otherwise, dataclass will generate this function by evaluating
|
||||||
|
# a string, and we will lose the information about the source file.
|
||||||
|
self.tensors = tensors
|
||||||
|
|
||||||
def __getitem__(self, key: Union[str, slice]):
|
def __getitem__(self, key: Union[str, slice]):
|
||||||
if isinstance(key, str):
|
if isinstance(key, str):
|
||||||
return self.tensors[key]
|
return self.tensors[key]
|
||||||
|
|||||||
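Why the hand-written __init__ matters for the traced-files cache: methods synthesized by @dataclass are compiled from a generated string, so their code objects do not point back to the defining source file, which is exactly the information the cache key needs. A small check of that observation (CPython behavior; the exact placeholder filename may vary by version):

from dataclasses import dataclass


@dataclass
class Generated:
    tensors: dict


class Manual:
    def __init__(self, tensors):
        self.tensors = tensors


print(Generated.__init__.__code__.co_filename)  # typically "<string>"
print(Manual.__init__.__code__.co_filename)     # the real path of this file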