mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
[torch.compile] PyTorch 2.6 and nightly compatibility (#12393)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
85ac82d228
commit
09b95e36ab
@ -92,7 +92,7 @@ def test_simple_piecewise_compile():
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||
num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
|
||||
@ -322,7 +322,7 @@ def test_toy_llama():
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_inductor_compilations=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_caputured=0,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=False))
|
||||
@ -332,7 +332,7 @@ def test_toy_llama():
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=1,
|
||||
num_piecewise_capturable_graphs_seen=1,
|
||||
num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
@ -345,7 +345,7 @@ def test_toy_llama():
|
||||
1, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=1 +
|
||||
llama_config.num_layers, # 1 + num_layers
|
||||
num_inductor_compilations=1 +
|
||||
num_backend_compilations=1 +
|
||||
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=2 *
|
||||
(1 + llama_config.num_layers
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import dataclasses
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
||||
from unittest.mock import patch
|
||||
@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
from .compiler_interface import EagerAdaptor, InductorAdaptor
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .monitor import end_monitoring_torch_compile
|
||||
@ -27,306 +26,128 @@ from .pass_manager import PostGradPassManager
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InductorArtifact:
|
||||
hash_str: str = ""
|
||||
file_path: str = ""
|
||||
|
||||
|
||||
class InductorHashCache:
|
||||
class CompilerManager:
|
||||
"""
|
||||
Disk format: a Python list of tuples, each tuple is
|
||||
(runtime_shape, graph_index, hash_str, file_path)
|
||||
We use list of tuple for readability.
|
||||
A manager to manage the compilation process, including
|
||||
caching the compiled graph, loading the compiled graph,
|
||||
and compiling the graph.
|
||||
|
||||
In-memory format: a defaultdict of dict, where the key is
|
||||
runtime_shape, and the value is a dict of graph_index to hash_str.
|
||||
The cache is a dict mapping
|
||||
`(runtime_shape, graph_index, backend_name)`
|
||||
to `any_data` returned from the compiler.
|
||||
|
||||
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
|
||||
we don't use json here because json doesn't support int as key.
|
||||
|
||||
TODO: better off-the-shelf solution to serialize the data?
|
||||
When serializing the cache, we save it to a Python file
|
||||
for readability. We don't use json here because json doesn't
|
||||
support int as key.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str, disabled: bool = False):
|
||||
self.cache: Dict[Optional[int],
|
||||
Dict[int, InductorArtifact]] = defaultdict(dict)
|
||||
self.disabled = disabled
|
||||
def __init__(self, use_inductor: bool):
|
||||
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
|
||||
cls = InductorAdaptor if use_inductor else EagerAdaptor
|
||||
self.compiler = cls()
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
self.disable_cache = disable_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_file_path = os.path.join(cache_dir,
|
||||
"inductor_hash_cache.py")
|
||||
if disabled:
|
||||
return
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
if os.path.exists(self.cache_file_path):
|
||||
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
||||
|
||||
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||
# load the cache from the file
|
||||
with open(self.cache_file_path) as f:
|
||||
self.deserialize(f.read())
|
||||
# we use ast.literal_eval to parse the data
|
||||
# because it is a safe way to parse Python literals.
|
||||
# do not use eval(), it is unsafe.
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
|
||||
def deserialize(self, data: str):
|
||||
# we use ast.literal_eval to parse the data
|
||||
# because it is a safe way to parse Python literals.
|
||||
# do not use eval(), it is unsafe.
|
||||
list_data = ast.literal_eval(data)
|
||||
for item in list_data:
|
||||
runtime_shape = item[0]
|
||||
graph_index = item[1]
|
||||
hash_str = item[2]
|
||||
# for compatibility of old version,
|
||||
# where we don't have file_path.
|
||||
# NOTE: after running the new code, the file_path
|
||||
# will be updated.
|
||||
file_path = "" if len(item) == 3 else item[3]
|
||||
self.cache[runtime_shape][graph_index] = InductorArtifact(
|
||||
hash_str=hash_str, file_path=file_path)
|
||||
|
||||
def serialize(self) -> str:
|
||||
data = []
|
||||
for runtime_shape, value in self.cache.items():
|
||||
for graph_index, inductor_artifact in value.items():
|
||||
data.append(
|
||||
(runtime_shape, graph_index, inductor_artifact.hash_str,
|
||||
inductor_artifact.file_path))
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
return printer.pformat(data)
|
||||
self.compiler.initialize_cache(cache_dir=cache_dir,
|
||||
disable_cache=disable_cache)
|
||||
|
||||
def save_to_file(self):
|
||||
if self.disabled:
|
||||
if self.disable_cache:
|
||||
return
|
||||
with open(self.cache_file_path, "w") as f:
|
||||
f.write(self.serialize())
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
data = printer.pformat(self.cache)
|
||||
f.write(data)
|
||||
|
||||
def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
|
||||
if self.disabled:
|
||||
return False
|
||||
runtime_shape, graph_index = key
|
||||
return runtime_shape in self.cache and graph_index in self.cache[
|
||||
runtime_shape]
|
||||
|
||||
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
|
||||
if self.disabled:
|
||||
raise KeyError("cannot read from disabled cache")
|
||||
runtime_shape, graph_index = key
|
||||
return self.cache[runtime_shape][graph_index]
|
||||
|
||||
def __setitem__(self, key: Tuple[Optional[int], int],
|
||||
value: InductorArtifact):
|
||||
# setitem for disabled cache is fine, because we
|
||||
# don't actually write to the disk
|
||||
runtime_shape, graph_index = key
|
||||
self.cache[runtime_shape][graph_index] = value
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
The following dummy methods are obtained by trial-and-error
|
||||
until it works.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: List[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
|
||||
def wrap_inductor(graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
vllm_backend: "VllmBackend",
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None,
|
||||
use_inductor: bool = True) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
if not use_inductor:
|
||||
return graph
|
||||
|
||||
compilation_counter.num_inductor_compilations += 1
|
||||
|
||||
from torch._inductor import config
|
||||
current_config = config.get_config_copy()
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
if additional_inductor_config is not None:
|
||||
current_config.update(additional_inductor_config)
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
current_config["max_autotune"] = True
|
||||
current_config["coordinate_descent_tuning"] = True
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
cache_data = vllm_backend.inductor_hash_cache
|
||||
if (runtime_shape, graph_index) in cache_data:
|
||||
# we compiled this graph before
|
||||
# so we can directly lookup the compiled graph via hash
|
||||
inductor_artifact = cache_data[(runtime_shape, graph_index)]
|
||||
hash_str = inductor_artifact.hash_str
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
logger.info(
|
||||
"Directly lookup the graph for shape %s from the cache",
|
||||
str(runtime_shape)) # noqa
|
||||
def load(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Optional[Callable]:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(handle, graph, example_inputs,
|
||||
graph_index, runtime_shape)
|
||||
logger.debug(
|
||||
"directly lookup the %s-th graph for shape %s via hash %s",
|
||||
graph_index, str(runtime_shape), hash_str)
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache file {cache_data.cache_file_path} and try again." # noqa
|
||||
)
|
||||
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
|
||||
"Directly load the %s-th graph for shape %s from %s via "
|
||||
"handle %s", graph_index, str(runtime_shape), self.compiler.name,
|
||||
handle)
|
||||
return compiled_graph
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
def compile(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
compilation_counter.num_backend_compilations += 1
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index,
|
||||
runtime_shape)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
logger.info("Directly load the compiled graph for shape %s "
|
||||
"from the cache", str(runtime_shape)) # noqa
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
# we need to compile it
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph, example_inputs, additional_inductor_config, runtime_shape)
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
# store the artifact in the cache
|
||||
if handle is not None:
|
||||
self.cache[(runtime_shape, graph_index,
|
||||
self.compiler.name)] = handle
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
logger.info("Cache the graph of shape %s for later use",
|
||||
str(runtime_shape))
|
||||
logger.debug(
|
||||
"store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index, str(runtime_shape), self.compiler.name, handle)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info("Compiling a graph for general shape takes %.2f s",
|
||||
elapsed)
|
||||
else:
|
||||
return graph_output[0]
|
||||
else:
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
logger.info("Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape, elapsed)
|
||||
|
||||
inductor_artifact = InductorArtifact()
|
||||
from torch._inductor.codecache import (FxGraphCache,
|
||||
compiled_fx_graph_hash)
|
||||
original_load = FxGraphCache.load
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
|
||||
return inductor_compiled_graph
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
inductor_artifact.hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if not cache_data.disabled:
|
||||
# compilation cache is enabled, patch several functions
|
||||
|
||||
# hijack to get the compiled graph itself
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache.load",
|
||||
hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash))
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache))
|
||||
|
||||
compiled_graph = compile_fx(graph,
|
||||
example_inputs,
|
||||
config_patches=current_config)
|
||||
# store the inductor_artifact in the cache
|
||||
cache_data[(runtime_shape, graph_index)] = inductor_artifact
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
logger.info("Cache the graph of shape %s for later use",
|
||||
str(runtime_shape))
|
||||
logger.debug(
|
||||
"store the %s-th graph for shape %s via hash %s from file %s",
|
||||
graph_index, str(runtime_shape), inductor_artifact.hash_str,
|
||||
inductor_artifact.file_path)
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info("Compiling a graph for general shape takes %.2f s",
|
||||
elapsed)
|
||||
else:
|
||||
logger.info("Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape, elapsed)
|
||||
|
||||
return compiled_graph
|
||||
return compiled_graph
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
compiled_graph_for_general_shape = wrap_inductor(
|
||||
compiled_graph_for_general_shape = self.vllm_backend.\
|
||||
compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
self.vllm_backend,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
use_inductor=self.compilation_config.use_inductor)
|
||||
runtime_shape=None)
|
||||
|
||||
self.module.__dict__[target] = PiecewiseBackend(
|
||||
submod, self.vllm_config, self.graph_pool, index,
|
||||
@ -483,7 +303,7 @@ class VllmBackend:
|
||||
post_grad_passes: Sequence[Callable]
|
||||
sym_tensor_indices: List[int]
|
||||
input_buffers: List[torch.Tensor]
|
||||
inductor_hash_cache: InductorHashCache
|
||||
compiler_manager: CompilerManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -507,6 +327,9 @@ class VllmBackend:
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.compiler_manager: CompilerManager = CompilerManager(
|
||||
self.compilation_config.use_inductor)
|
||||
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
|
||||
@ -533,9 +356,11 @@ class VllmBackend:
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
|
||||
factors = []
|
||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||
# model is created)
|
||||
config_hash = vllm_config.compute_hash()
|
||||
factors.append(config_hash)
|
||||
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
@ -553,10 +378,15 @@ class VllmBackend:
|
||||
import hashlib
|
||||
code_hash = hashlib.md5(
|
||||
"\n".join(hash_content).encode()).hexdigest()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10]
|
||||
|
||||
# combine the two hashes to generate the cache dir
|
||||
hash_key = hashlib.md5(
|
||||
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
@ -570,15 +400,16 @@ class VllmBackend:
|
||||
cache_dir, f"rank_{vllm_config.parallel_config.rank}")
|
||||
self.compilation_config.local_cache_dir = local_cache_dir
|
||||
|
||||
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
|
||||
local_cache_dir, disabled=disabled)
|
||||
if disabled:
|
||||
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
|
||||
if disable_cache:
|
||||
logger.info("vLLM's torch.compile cache is disabled.")
|
||||
else:
|
||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||
local_cache_dir)
|
||||
|
||||
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
@ -759,7 +590,7 @@ class PiecewiseBackend:
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.inductor_hash_cache.save_to_file()
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
@ -782,16 +613,14 @@ class PiecewiseBackend:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = wrap_inductor(
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
self.vllm_backend,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
use_inductor=self.compilation_config.use_inductor)
|
||||
runtime_shape=runtime_shape)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
|
||||
340
vllm/compilation/compiler_interface.py
Normal file
340
vllm/compilation/compiler_interface.py
Normal file
@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
"""
|
||||
Gather all the relevant information from the VLLM config,
|
||||
to compute a hash so that we can cache the compiled model.
|
||||
|
||||
See :meth:`VllmConfig.compute_hash` to check what information
|
||||
is already considered by default. This function should only
|
||||
consider the information that is specific to the compiler.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batchsize
|
||||
(number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
The function should return a compiled callable function, as well as
|
||||
a handle that can be used to directly load the compiled function.
|
||||
|
||||
The handle should be a plain Python object, preferably a string or a
|
||||
file path for readability.
|
||||
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
|
||||
The handle is the second return value of the `compile` function.
|
||||
"""
|
||||
raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
The following dummy methods are obtained by trial-and-error
|
||||
until it works.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: List[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5 and 2.6.
|
||||
"""
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors: List[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a sub-directory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
from torch._inductor import config
|
||||
current_config = config.get_config_copy()
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
current_config["max_autotune"] = True
|
||||
current_config["coordinate_descent_tuning"] = True
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import (FxGraphCache,
|
||||
compiled_fx_graph_hash)
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
|
||||
return inductor_compiled_graph
|
||||
|
||||
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
|
||||
elif torch.__version__ >= "2.6":
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(
|
||||
*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# hijack to get the compiled graph itself
|
||||
if original_load_name is not None:
|
||||
stack.enter_context(patch(original_load_name, hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash))
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache))
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config)
|
||||
|
||||
assert hash_str is not None, (
|
||||
"failed to get the hash of the compiled graph")
|
||||
assert file_path is not None, (
|
||||
"failed to get the file path of the compiled graph")
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()):
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import (
|
||||
CompiledFxGraphConstantsWithGm)
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
return graph, None
|
||||
@ -13,7 +13,7 @@ class CompilationCounter:
|
||||
num_piecewise_graphs_seen: int = 0
|
||||
# not including the splitting ops
|
||||
num_piecewise_capturable_graphs_seen: int = 0
|
||||
num_inductor_compilations: int = 0
|
||||
num_backend_compilations: int = 0
|
||||
num_cudagraph_caputured: int = 0
|
||||
|
||||
def clone(self) -> "CompilationCounter":
|
||||
|
||||
@ -13,7 +13,6 @@ from torch import fx
|
||||
class InductorPass(ABC):
|
||||
"""
|
||||
General custom inductor pass interface.
|
||||
TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PostGradPassManager:
|
||||
class PlaceHolder:
|
||||
pass
|
||||
|
||||
|
||||
if torch.__version__ < "2.6":
|
||||
Parent = PlaceHolder # type: ignore
|
||||
else:
|
||||
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
|
||||
|
||||
|
||||
class PostGradPassManager(Parent):
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
@ -55,6 +66,9 @@ class PostGradPassManager:
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self):
|
||||
return self.__getstate__()
|
||||
|
||||
def __getstate__(self) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Custom pickling for the pass manager, as some passes cannot be pickled.
|
||||
|
||||
@ -3072,15 +3072,6 @@ class VllmConfig:
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: List[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
|
||||
# summarize vllm config
|
||||
vllm_factors: List[Any] = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user