[torch.compile] PyTorch 2.6 and nightly compatibility (#12393)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent 85ac82d228
commit 09b95e36ab
@@ -92,7 +92,7 @@ def test_simple_piecewise_compile():
             num_graphs_seen=1,  # one graph for the model
             num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
             num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
-            num_inductor_compilations=3,  # num_piecewise_capturable_graphs_seen
+            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_caputured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
@@ -322,7 +322,7 @@ def test_toy_llama():
             num_graphs_seen=0,
             num_piecewise_graphs_seen=0,
             num_piecewise_capturable_graphs_seen=0,
-            num_inductor_compilations=0,
+            num_backend_compilations=0,
             num_cudagraph_caputured=0,
     ):
         outputs.append(run_model(llama_config, use_compile=False))
@@ -332,7 +332,7 @@ def test_toy_llama():
             num_graphs_seen=1,  # one graph for the model
             num_piecewise_graphs_seen=1,
             num_piecewise_capturable_graphs_seen=1,
-            num_inductor_compilations=1,  # num_piecewise_capturable_graphs_seen
+            num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_caputured=
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
@@ -345,7 +345,7 @@ def test_toy_llama():
             1,  # 2 * num_layers + 1
             num_piecewise_capturable_graphs_seen=1 +
             llama_config.num_layers,  # 1 + num_layers
-            num_inductor_compilations=1 +
+            num_backend_compilations=1 +
             llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_caputured=2 *
             (1 + llama_config.num_layers
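The counter fields above are expected values asserted by the tests; they are presumably passed to a context-manager helper on `compilation_counter` that snapshots the counters on entry and checks the deltas on exit. A minimal standalone sketch of that pattern (hypothetical names, not the code from this patch):

    import contextlib
    import dataclasses

    @dataclasses.dataclass
    class Counters:
        num_graphs_seen: int = 0
        num_backend_compilations: int = 0

        @contextlib.contextmanager
        def expect(self, **expected_deltas):
            before = dataclasses.asdict(self)
            yield
            after = dataclasses.asdict(self)
            for name, delta in expected_deltas.items():
                # each keyword names a counter and the delta it should have gained
                assert after[name] - before[name] == delta, name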
@@ -1,12 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import ast
-import copy
 import dataclasses
 import os
 import pprint
 import time
-from collections import defaultdict
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors

+from .compiler_interface import EagerAdaptor, InductorAdaptor
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -27,293 +26,115 @@ from .pass_manager import PostGradPassManager
 logger = init_logger(__name__)


-@dataclasses.dataclass
-class InductorArtifact:
-    hash_str: str = ""
-    file_path: str = ""
-
-
-class InductorHashCache:
-    """
-    Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str, file_path)
-    We use list of tuple for readability.
-
-    In-memory format: a defaultdict of dict, where the key is
-    runtime_shape, and the value is a dict of graph_index to hash_str.
-
-    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
-    we don't use json here because json doesn't support int as key.
-
-    TODO: better off-the-shelf solution to serialize the data?
-    """
-
-    def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: Dict[Optional[int],
-                         Dict[int, InductorArtifact]] = defaultdict(dict)
-        self.disabled = disabled
-        self.cache_dir = cache_dir
-        self.cache_file_path = os.path.join(cache_dir,
-                                            "inductor_hash_cache.py")
-        if disabled:
-            return
-        # set flags so that Inductor and Triton store their cache
-        # in the cache_dir, then users only need to copy the cache_dir
-        # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
-        os.makedirs(inductor_cache, exist_ok=True)
-        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
-        os.makedirs(triton_cache, exist_ok=True)
-        os.environ["TRITON_CACHE_DIR"] = triton_cache
-        if os.path.exists(self.cache_file_path):
-            with open(self.cache_file_path) as f:
-                self.deserialize(f.read())
-
-    def deserialize(self, data: str):
-        # we use ast.literal_eval to parse the data
-        # because it is a safe way to parse Python literals.
-        # do not use eval(), it is unsafe.
-        list_data = ast.literal_eval(data)
-        for item in list_data:
-            runtime_shape = item[0]
-            graph_index = item[1]
-            hash_str = item[2]
-            # for compatibility of old version,
-            # where we don't have file_path.
-            # NOTE: after running the new code, the file_path
-            # will be updated.
-            file_path = "" if len(item) == 3 else item[3]
-            self.cache[runtime_shape][graph_index] = InductorArtifact(
-                hash_str=hash_str, file_path=file_path)
-
-    def serialize(self) -> str:
-        data = []
-        for runtime_shape, value in self.cache.items():
-            for graph_index, inductor_artifact in value.items():
-                data.append(
-                    (runtime_shape, graph_index, inductor_artifact.hash_str,
-                     inductor_artifact.file_path))
-        printer = pprint.PrettyPrinter(indent=4)
-        return printer.pformat(data)
-
-    def save_to_file(self):
-        if self.disabled:
-            return
-        with open(self.cache_file_path, "w") as f:
-            f.write(self.serialize())
-
-    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
-        if self.disabled:
-            return False
-        runtime_shape, graph_index = key
-        return runtime_shape in self.cache and graph_index in self.cache[
-            runtime_shape]
-
-    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
-        if self.disabled:
-            raise KeyError("cannot read from disabled cache")
-        runtime_shape, graph_index = key
-        return self.cache[runtime_shape][graph_index]
-
-    def __setitem__(self, key: Tuple[Optional[int], int],
-                    value: InductorArtifact):
-        # setitem for disabled cache is fine, because we
-        # don't actually write to the disk
-        runtime_shape, graph_index = key
-        self.cache[runtime_shape][graph_index] = value
-
-
-class AlwaysHitShapeEnv:
-    """
-    Why do we need this class:
-
-    For normal `torch.compile` usage, every compilation will have
-    one Dynamo bytecode compilation and one Inductor compilation.
-    The Inductor compilation happens under the context of the
-    Dynamo bytecode compilation, and that context is used to
-    determine the dynamic shape information, etc.
-
-    For our use case, we only run Dynamo bytecode compilation once,
-    and run Inductor compilation multiple times with different shapes
-    plus a general shape. The compilation for specific shapes happens
-    outside of the context of the Dynamo bytecode compilation. At that
-    time, we don't have shape environment to provide to Inductor, and
-    it will fail the Inductor code cache lookup.
-
-    By providing a dummy shape environment that always hits, we can
-    make the Inductor code cache lookup always hit, and we can
-    compile the graph for different shapes as needed.
-
-    The following dummy methods are obtained by trial-and-error
-    until it works.
-    """
-
-    def __init__(self) -> None:
-        self.guards: List[Any] = []
-
-    def evaluate_guards_expression(self, *args, **kwargs):
-        return True
-
-    def get_pruned_guards(self, *args, **kwargs):
-        return []
-
-    def produce_guards_expression(self, *args, **kwargs):
-        return ""
-
-
-def wrap_inductor(graph: fx.GraphModule,
-                  example_inputs,
-                  additional_inductor_config,
-                  compilation_config: CompilationConfig,
-                  vllm_backend: "VllmBackend",
-                  graph_index: int = 0,
-                  num_graphs: int = 1,
-                  runtime_shape: Optional[int] = None,
-                  use_inductor: bool = True) -> Any:
-    if graph_index == 0:
-        # before compiling the first graph, record the start time
-        global compilation_start_time
-        compilation_start_time = time.time()
-
-    if not use_inductor:
-        return graph
-
-    compilation_counter.num_inductor_compilations += 1
-
-    from torch._inductor import config
-    current_config = config.get_config_copy()
-    from torch._inductor.compile_fx import compile_fx
-
-    if additional_inductor_config is not None:
-        current_config.update(additional_inductor_config)
-
-    if isinstance(runtime_shape, int):
-        # for a specific batchsize, tuning triton kernel parameters
-        # can be beneficial
-        current_config["max_autotune"] = True
-        current_config["coordinate_descent_tuning"] = True
-
-    # inductor can inplace modify the graph, so we need to copy it
-    # see https://github.com/pytorch/pytorch/issues/138980
-    graph = copy.deepcopy(graph)
-
-    cache_data = vllm_backend.inductor_hash_cache
-    if (runtime_shape, graph_index) in cache_data:
-        # we compiled this graph before
-        # so we can directly lookup the compiled graph via hash
-        inductor_artifact = cache_data[(runtime_shape, graph_index)]
-        hash_str = inductor_artifact.hash_str
-        if graph_index == 0:
-            # adds some info logging for the first graph
-            logger.info(
-                "Directly lookup the graph for shape %s from the cache",
-                str(runtime_shape))  # noqa
-        logger.debug(
-            "directly lookup the %s-th graph for shape %s via hash %s",
-            graph_index, str(runtime_shape), hash_str)
-        from torch._inductor.codecache import FxGraphCache
-        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
-            inductor_compiled_graph = FxGraphCache._lookup_graph(
-                hash_str, example_inputs, True, False)
-            assert inductor_compiled_graph is not None, (
-                "Inductor cache lookup failed. Please remove"
-                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
-            )
-            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
-
-        # Inductor calling convention (function signature):
-        # f(list) -> tuple
-        # Dynamo calling convention (function signature):
-        # f(*args) -> Any
-
-        # need to know if the graph returns a tuple
-        from torch._inductor.compile_fx import graph_returns_tuple
-        returns_tuple = graph_returns_tuple(graph)
-
-        # this is the callable we return to Dynamo to run
-        def compiled_graph(*args):
-            # convert args to list
-            list_args = list(args)
-            graph_output = inductor_compiled_graph(list_args)
-            # unpack the tuple if needed
-            if returns_tuple:
-                return graph_output
-            else:
-                return graph_output[0]
-    else:
-        # it's the first time we compile this graph
-        # the assumption is that we don't have nested Inductor compilation.
-        # compiled_fx_graph_hash will only be called once, and we can hook
-        # it to get the hash of the compiled graph directly.
-
-        inductor_artifact = InductorArtifact()
-        from torch._inductor.codecache import (FxGraphCache,
-                                               compiled_fx_graph_hash)
-        original_load = FxGraphCache.load
-
-        def hijack_load(*args, **kwargs):
-            inductor_compiled_graph = original_load(*args, **kwargs)
-            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
-            return inductor_compiled_graph
-
-        def hijack_compiled_fx_graph_hash(*args, **kwargs):
-            out = compiled_fx_graph_hash(*args, **kwargs)
-            inductor_artifact.hash_str = out[0]
-            return out
-
-        def _check_can_cache(*args, **kwargs):
-            # no error means it can be cached.
-            # Inductor refuses to cache the graph outside of Dynamo
-            # tracing context, and also disables caching for graphs
-            # with high-order ops.
-            # For vLLM, in either case, we want to cache the graph.
-            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
-            return
-
-        def _get_shape_env() -> AlwaysHitShapeEnv:
-            return AlwaysHitShapeEnv()
-
-        with ExitStack() as stack:
-            if not cache_data.disabled:
-                # compilation cache is enabled, patch several functions
-
-                # hijack to get the compiled graph itself
-                stack.enter_context(
-                    patch("torch._inductor.codecache.FxGraphCache.load",
-                          hijack_load))
-
-                # for hijacking the hash of the compiled graph
-                stack.enter_context(
-                    patch("torch._inductor.codecache.compiled_fx_graph_hash",
-                          hijack_compiled_fx_graph_hash))
-
-                # for providing a dummy shape environment
-                stack.enter_context(
-                    patch(
-                        "torch._inductor.codecache.FxGraphCache._get_shape_env",
-                        _get_shape_env))
-
-                # for forcing the graph to be cached
-                stack.enter_context(
-                    patch(
-                        "torch._inductor.codecache.FxGraphCache._check_can_cache",
-                        _check_can_cache))
-
-            compiled_graph = compile_fx(graph,
-                                        example_inputs,
-                                        config_patches=current_config)
-        # store the inductor_artifact in the cache
-        cache_data[(runtime_shape, graph_index)] = inductor_artifact
-        if graph_index == 0:
-            # adds some info logging for the first graph
-            logger.info("Cache the graph of shape %s for later use",
-                        str(runtime_shape))
-        logger.debug(
-            "store the %s-th graph for shape %s via hash %s from file %s",
-            graph_index, str(runtime_shape), inductor_artifact.hash_str,
-            inductor_artifact.file_path)
-    # after compiling the last graph, record the end time
-    if graph_index == num_graphs - 1:
-        now = time.time()
+class CompilerManager:
+    """
+    A manager to manage the compilation process, including
+    caching the compiled graph, loading the compiled graph,
+    and compiling the graph.
+
+    The cache is a dict mapping
+    `(runtime_shape, graph_index, backend_name)`
+    to `any_data` returned from the compiler.
+
+    When serializing the cache, we save it to a Python file
+    for readability. We don't use json here because json doesn't
+    support int as key.
+    """
+
+    def __init__(self, use_inductor: bool):
+        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        cls = InductorAdaptor if use_inductor else EagerAdaptor
+        self.compiler = cls()
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        return self.compiler.compute_hash(vllm_config)
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.disable_cache = disable_cache
+        self.cache_dir = cache_dir
+        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
+
+        if not disable_cache and os.path.exists(self.cache_file_path):
+            # load the cache from the file
+            with open(self.cache_file_path) as f:
+                # we use ast.literal_eval to parse the data
+                # because it is a safe way to parse Python literals.
+                # do not use eval(), it is unsafe.
+                self.cache = ast.literal_eval(f.read())
+
+        self.compiler.initialize_cache(cache_dir=cache_dir,
+                                       disable_cache=disable_cache)
+
+    def save_to_file(self):
+        if self.disable_cache:
+            return
+        with open(self.cache_file_path, "w") as f:
+            printer = pprint.PrettyPrinter(indent=4)
+            data = printer.pformat(self.cache)
+            f.write(data)
+
+    def load(self,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Optional[Callable]:
+        if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
+            return None
+        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
+        compiled_graph = self.compiler.load(handle, graph, example_inputs,
+                                            graph_index, runtime_shape)
+        logger.debug(
+            "Directly load the %s-th graph for shape %s from %s via "
+            "handle %s", graph_index, str(runtime_shape), self.compiler.name,
+            handle)
+        return compiled_graph
+
+    def compile(self,
+                graph: fx.GraphModule,
+                example_inputs,
+                additional_inductor_config,
+                compilation_config: CompilationConfig,
+                graph_index: int = 0,
+                num_graphs: int = 1,
+                runtime_shape: Optional[int] = None) -> Any:
+        if graph_index == 0:
+            # before compiling the first graph, record the start time
+            global compilation_start_time
+            compilation_start_time = time.time()
+
+        compilation_counter.num_backend_compilations += 1
+
+        compiled_graph = None
+
+        # try to load from the cache
+        compiled_graph = self.load(graph, example_inputs, graph_index,
+                                   runtime_shape)
+        if compiled_graph is not None:
+            if graph_index == 0:
+                # adds some info logging for the first graph
+                logger.info("Directly load the compiled graph for shape %s "
+                            "from the cache", str(runtime_shape))  # noqa
+            return compiled_graph
+
+        # no compiler cached the graph, or the cache is disabled,
+        # we need to compile it
+        compiled_graph, handle = self.compiler.compile(
+            graph, example_inputs, additional_inductor_config, runtime_shape)
+
+        assert compiled_graph is not None, "Failed to compile the graph"
+
+        # store the artifact in the cache
+        if handle is not None:
+            self.cache[(runtime_shape, graph_index,
+                        self.compiler.name)] = handle
+            if graph_index == 0:
+                # adds some info logging for the first graph
+                logger.info("Cache the graph of shape %s for later use",
+                            str(runtime_shape))
+            logger.debug(
+                "store the %s-th graph for shape %s from %s via handle %s",
+                graph_index, str(runtime_shape), self.compiler.name, handle)
+
+        # after compiling the last graph, record the end time
+        if graph_index == num_graphs - 1:
+            now = time.time()
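A note on the new cache format: CompilerManager keys its cache with tuples like `(runtime_shape, graph_index, compiler_name)`, which JSON cannot express as keys, hence the pprint/ast.literal_eval round trip above. A small standalone sketch of that serialization choice (illustrative values only, not part of the patch):

    import ast
    import pprint

    # keyed the way CompilerManager keys its cache: (runtime_shape, graph_index, name)
    cache = {
        (None, 0, "inductor"): ("hash_abc", "/path/to/general_shape_graph.py"),
        (8, 0, "inductor"): ("hash_def", "/path/to/shape_8_graph.py"),
    }

    # serialize as a readable Python literal instead of JSON
    text = pprint.PrettyPrinter(indent=4).pformat(cache)

    # ast.literal_eval parses Python literals safely (no eval); tuple keys,
    # ints and None all survive the round trip, unlike with JSON
    assert ast.literal_eval(text) == cache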
@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
             ]
             global compilation_start_time
-            compiled_graph_for_general_shape = wrap_inductor(
+            compiled_graph_for_general_shape = self.vllm_backend.\
+                compiler_manager.compile(
                 submod,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
-                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
-                runtime_shape=None,
-                use_inductor=self.compilation_config.use_inductor)
+                runtime_shape=None)

             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
@@ -483,7 +303,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
-    inductor_hash_cache: InductorHashCache
+    compiler_manager: CompilerManager

     def __init__(
         self,
@@ -507,6 +327,9 @@ class VllmBackend:
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config

+        self.compiler_manager: CompilerManager = CompilerManager(
+            self.compilation_config.use_inductor)
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

@@ -533,9 +356,11 @@ class VllmBackend:
         # the cache dir will be the same so that we can reuse the compiled
         # graph.

+        factors = []
         # 1. factors come from the vllm_config (it mainly summarizes how the
         # model is created)
         config_hash = vllm_config.compute_hash()
+        factors.append(config_hash)

         # 2. factors come from the code files that are traced by Dynamo (
         # it mainly summarizes how the model is used in forward pass)
@@ -553,10 +378,15 @@ class VllmBackend:
         import hashlib
         code_hash = hashlib.md5(
             "\n".join(hash_content).encode()).hexdigest()
+        factors.append(code_hash)
+
+        # 3. compiler hash
+        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
+        factors.append(compiler_hash)
+
+        # combine all factors to generate the cache dir
+        hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10]

-        # combine the two hashes to generate the cache dir
-        hash_key = hashlib.md5(
-            f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
         cache_dir = os.path.join(
             envs.VLLM_CACHE_ROOT,
             "torch_compile_cache",
@@ -570,15 +400,16 @@ class VllmBackend:
             cache_dir, f"rank_{vllm_config.parallel_config.rank}")
         self.compilation_config.local_cache_dir = local_cache_dir

-        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
-        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            local_cache_dir, disabled=disabled)
-        if disabled:
+        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
+
+        if disable_cache:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
                         local_cache_dir)

+        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
@@ -759,7 +590,7 @@ class PiecewiseBackend:
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.vllm_backend.inductor_hash_cache.save_to_file()
+            self.vllm_backend.compiler_manager.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)

     def __call__(self, *args) -> Any:
@@ -782,16 +613,14 @@ class PiecewiseBackend:
             entry.compiled = True
             self.to_be_compiled_sizes.remove(runtime_shape)
             # args are real arguments
-            entry.runnable = wrap_inductor(
+            entry.runnable = self.vllm_backend.compiler_manager.compile(
                 self.graph,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
-                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape,
-                use_inductor=self.compilation_config.use_inductor)
+                runtime_shape=runtime_shape)

             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
vllm/compilation/compiler_interface.py (new file, 340 lines)
@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch

import torch
import torch._inductor.compile_fx
import torch.fx as fx

from vllm.config import VllmConfig


class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.
    """
    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        """
        when the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by re-directing its own cache directory to a sub-directory.
        """
        pass

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Gather all the relevant information from the VLLM config,
        to compute a hash so that we can cache the compiled model.

        See :meth:`VllmConfig.compute_hash` to check what information
        is already considered by default. This function should only
        consider the information that is specific to the compiler.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the graph with the given example inputs and compiler config,
        with a runtime shape. If the `runtime_shape` is None, it means
        the `example_inputs` have a dynamic shape. Otherwise, the
        `runtime_shape` specifies the shape of the inputs. Right now we only
        support one variable shape for all inputs, which is the batchsize
        (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.
        """
        return None, None

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.
        """
        raise NotImplementedError("caching is not supported")


class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have shape environment to provide to Inductor, and
    it will fail the Inductor code cache lookup.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods are obtained by trial-and-error
    until it works.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""


class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5 and 2.6.
    """
    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        factors: List[Any] = []
        # summarize system state
        from torch._inductor.codecache import CacheBase
        system_factors = CacheBase.get_system()
        factors.append(system_factors)

        # summarize pytorch state
        from torch._inductor.codecache import torch_key
        torch_factors = torch_key()
        factors.append(torch_factors)
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
        return hash_str

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        if disable_cache:
            return
        # redirect the cache directory to a sub-directory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        from torch._inductor import config
        current_config = config.get_config_copy()
        from torch._inductor.compile_fx import compile_fx

        # disable remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        if compiler_config is not None:
            current_config.update(compiler_config)

        if isinstance(runtime_shape, int):
            # for a specific batchsize, tuning triton kernel parameters
            # can be beneficial
            current_config["max_autotune"] = True
            current_config["coordinate_descent_tuning"] = True

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)

        if torch.__version__.startswith("2.5"):
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                inductor_compiled_graph = original_load(*args, **kwargs)
                nonlocal file_path
                file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
                return inductor_compiled_graph

            hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner  # noqa
        elif torch.__version__ >= "2.6":
            # function renamed in 2.6
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                output = torch._inductor.compile_fx.compile_fx_inner(
                    *args, **kwargs)
                nonlocal hash_str
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    nonlocal file_path
                    file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch("torch._inductor.codecache.compiled_fx_graph_hash",
                      hijack_compiled_fx_graph_hash))

            # for providing a dummy shape environment
            stack.enter_context(
                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                      _get_shape_env))

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache))

            compiled_graph = compile_fx(
                graph,
                example_inputs,
                inner_compile=hijacked_compile_fx_inner,
                config_patches=current_config)

        assert hash_str is not None, (
            "failed to get the hash of the compiled graph")
        assert file_path is not None, (
            "failed to get the file path of the compiled graph")
        return compiled_graph, (hash_str, file_path)

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            if torch.__version__.startswith("2.5"):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False)
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove"
                    f"the cache directory and try again."  # noqa
                )
            elif torch.__version__ >= "2.6":
                from torch._inductor.output_code import (
                    CompiledFxGraphConstantsWithGm)
                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants)
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove"
                    f"the cache directory and try again."  # noqa
                )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph


class EagerAdaptor(CompilerInterface):
    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        # we don't need to compile the graph, just return the graph itself.
        # It does not support caching, return None for the handle.
        return graph, None
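The contract above is: `compile()` returns a `(callable, handle)` pair, and `load(handle, ...)` must be able to rebuild the callable later from that handle alone. A minimal sketch of exercising the interface with the EagerAdaptor defined in this file (illustrative only, not part of the patch):

    import torch
    import torch.fx as fx

    def f(x):
        return x * 2 + 1

    gm = fx.symbolic_trace(f)
    adaptor = EagerAdaptor()
    compiled, handle = adaptor.compile(gm, [torch.randn(4)],
                                       compiler_config={}, runtime_shape=None)
    assert handle is None             # EagerAdaptor opts out of caching
    out = compiled(torch.randn(4))    # the "compiled" callable is just the FX module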
@@ -13,7 +13,7 @@ class CompilationCounter:
     num_piecewise_graphs_seen: int = 0
     # not including the splitting ops
     num_piecewise_capturable_graphs_seen: int = 0
-    num_inductor_compilations: int = 0
+    num_backend_compilations: int = 0
     num_cudagraph_caputured: int = 0

     def clone(self) -> "CompilationCounter":
@@ -13,7 +13,6 @@ from torch import fx
 class InductorPass(ABC):
     """
     General custom inductor pass interface.
-    TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
     """

     @abstractmethod
@@ -2,6 +2,7 @@

 from typing import Any, Dict, List

+import torch
 from torch import fx as fx

 from vllm.config import CompilationConfig
@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
 logger = init_logger(__name__)


-class PostGradPassManager:
+class PlaceHolder:
+    pass
+
+
+if torch.__version__ < "2.6":
+    Parent = PlaceHolder  # type: ignore
+else:
+    Parent = torch._inductor.custom_graph_pass.CustomGraphPass  # type: ignore
+
+
+class PostGradPassManager(Parent):
     """
     The pass manager for post-grad passes.
     It handles configuration, adding custom passes, and running passes.
@@ -55,6 +66,9 @@ class PostGradPassManager:
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)

+    def uuid(self):
+        return self.__getstate__()
+
     def __getstate__(self) -> Dict[str, List[Any]]:
         """
         Custom pickling for the pass manager, as some passes cannot be pickled.
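Context for the change above: on torch >= 2.6, Inductor expects custom post-grad passes to be instances of `torch._inductor.custom_graph_pass.CustomGraphPass`, which requires a `uuid()` so the pass identity can feed Inductor's FX-graph cache key; here `PostGradPassManager.uuid()` simply reuses its pickled state. A minimal sketch of that contract (assuming torch >= 2.6; the pass name and body are hypothetical):

    import torch
    from torch._inductor.custom_graph_pass import CustomGraphPass

    class NoopPass(CustomGraphPass):
        def __call__(self, graph: torch.fx.Graph) -> None:
            # rewrite the post-grad FX graph in place here
            pass

        def uuid(self):
            # any stable identifier; changing it invalidates cached graphs
            return b"noop-pass-v1"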
@@ -3072,15 +3072,6 @@ class VllmConfig:
         the final hidden states.
         """
         factors: List[Any] = []
-        # summarize system state
-        from torch._inductor.codecache import CacheBase
-        system_factors = CacheBase.get_system()
-        factors.append(system_factors)
-
-        # summarize pytorch state
-        from torch._inductor.codecache import torch_key
-        torch_factors = torch_key()
-        factors.append(torch_factors)

         # summarize vllm config
         vllm_factors: List[Any] = []