mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 15:27:07 +08:00
AOT Compilation for torch.compile (Bundled) (#24274)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
parent
e317414ce1
commit
eef921f45e
@ -403,6 +403,7 @@ steps:
|
|||||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||||
- pytest -v -s compile/test_decorator.py
|
- pytest -v -s compile/test_decorator.py
|
||||||
- pytest -v -s compile/test_noop_elimination.py
|
- pytest -v -s compile/test_noop_elimination.py
|
||||||
|
- pytest -v -s compile/test_aot_compile.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
|
|||||||
139
tests/compile/test_aot_compile.py
Normal file
139
tests/compile/test_aot_compile.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import (
|
||||||
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
|
def reference_fn(x: torch.Tensor):
|
||||||
|
assert x.shape[0] <= 42
|
||||||
|
assert x.shape[0] % 2 == 0
|
||||||
|
for _ in range(3000):
|
||||||
|
x = x + x.shape[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class CompiledMod(torch.nn.Module):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return reference_fn(x)
|
||||||
|
|
||||||
|
|
||||||
|
def make_vllm_config() -> VllmConfig:
|
||||||
|
return VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
level=CompilationLevel.PIECEWISE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_vllm_config(vllm_config: VllmConfig):
|
||||||
|
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||||
|
)
|
||||||
|
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
args = (torch.randn(10, 10),)
|
||||||
|
expected = reference_fn(*args)
|
||||||
|
with use_vllm_config(vllm_config):
|
||||||
|
m.setenv("VLLM_USE_AOT_COMPILE", "0")
|
||||||
|
with (
|
||||||
|
pytest.raises(RuntimeError, match="Detected recompile"),
|
||||||
|
torch.compiler.set_stance("fail_on_recompile"),
|
||||||
|
):
|
||||||
|
CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
|
||||||
|
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||||
|
torch._dynamo.reset()
|
||||||
|
with torch.compiler.set_stance("fail_on_recompile"):
|
||||||
|
actual = CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
assert torch.allclose(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||||
|
)
|
||||||
|
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
|
||||||
|
args = (torch.randn(10, 10),)
|
||||||
|
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||||
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
|
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
|
||||||
|
CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||||
|
)
|
||||||
|
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
args = (torch.randn(10, 10),)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||||
|
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
with use_vllm_config(vllm_config):
|
||||||
|
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
|
||||||
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
with use_vllm_config(vllm_config):
|
||||||
|
ret = CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
assert torch.allclose(ret, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||||
|
)
|
||||||
|
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""
|
||||||
|
Test that the shape environment is correctly serialized and preserved
|
||||||
|
when loading from cache.
|
||||||
|
"""
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
args = (torch.randn(10, 10),)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||||
|
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
with use_vllm_config(vllm_config):
|
||||||
|
compiled_mod = CompiledMod(vllm_config=vllm_config)
|
||||||
|
compiled_mod(*args)
|
||||||
|
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||||
|
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||||
|
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||||
|
|
||||||
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
|
vllm_config = make_vllm_config()
|
||||||
|
with use_vllm_config(vllm_config):
|
||||||
|
compiled_mod = CompiledMod(vllm_config=vllm_config)
|
||||||
|
compiled_mod(*args)
|
||||||
|
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||||
|
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||||
|
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||||
@ -22,6 +22,7 @@ ALLOWED_FILES = {
|
|||||||
"vllm/multimodal/hasher.py",
|
"vllm/multimodal/hasher.py",
|
||||||
"vllm/transformers_utils/config.py",
|
"vllm/transformers_utils/config.py",
|
||||||
"vllm/model_executor/models/registry.py",
|
"vllm/model_executor/models/registry.py",
|
||||||
|
"vllm/compilation/caching.py",
|
||||||
"tests/utils_/test_utils.py",
|
"tests/utils_/test_utils.py",
|
||||||
"tests/tokenization/test_cached_tokenizer.py",
|
"tests/tokenization/test_cached_tokenizer.py",
|
||||||
"vllm/distributed/utils.py",
|
"vllm/distributed/utils.py",
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import ast
|
import ast
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
import time
|
import time
|
||||||
@ -25,6 +26,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||||
|
|
||||||
|
from .caching import VllmSerializableFunction
|
||||||
from .compiler_interface import (
|
from .compiler_interface import (
|
||||||
CompilerInterface,
|
CompilerInterface,
|
||||||
EagerAdaptor,
|
EagerAdaptor,
|
||||||
@ -195,6 +197,7 @@ class CompilerManager:
|
|||||||
# there can be multiple graphs due to piecewise compilation.
|
# there can be multiple graphs due to piecewise compilation.
|
||||||
now = time.time()
|
now = time.time()
|
||||||
elapsed = now - compilation_start_time
|
elapsed = now - compilation_start_time
|
||||||
|
compilation_config.compilation_time += elapsed
|
||||||
if runtime_shape is None:
|
if runtime_shape is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Directly load the compiled graph(s) for dynamic shape "
|
"Directly load the compiled graph(s) for dynamic shape "
|
||||||
@ -549,7 +552,11 @@ class VllmBackend:
|
|||||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||||
|
|
||||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
def __call__(
|
||||||
|
self, graph: fx.GraphModule, example_inputs
|
||||||
|
) -> VllmSerializableFunction:
|
||||||
|
from .caching import _compute_code_hash, compilation_config_hash_factors
|
||||||
|
|
||||||
vllm_config = self.vllm_config
|
vllm_config = self.vllm_config
|
||||||
if not self.compilation_config.cache_dir:
|
if not self.compilation_config.cache_dir:
|
||||||
# no provided cache dir, generate one based on the known factors
|
# no provided cache dir, generate one based on the known factors
|
||||||
@ -557,39 +564,11 @@ class VllmBackend:
|
|||||||
# the cache dir will be the same so that we can reuse the compiled
|
# the cache dir will be the same so that we can reuse the compiled
|
||||||
# graph.
|
# graph.
|
||||||
|
|
||||||
factors = []
|
factors = compilation_config_hash_factors(vllm_config)
|
||||||
# 0. factors come from the env, for example, The values of
|
|
||||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
|
||||||
env_hash = envs.compute_hash()
|
|
||||||
factors.append(env_hash)
|
|
||||||
|
|
||||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
|
||||||
# model is created)
|
|
||||||
config_hash = vllm_config.compute_hash()
|
|
||||||
factors.append(config_hash)
|
|
||||||
|
|
||||||
# 2. factors come from the code files that are traced by Dynamo (
|
# 2. factors come from the code files that are traced by Dynamo (
|
||||||
# it mainly summarizes how the model is used in forward pass)
|
# it mainly summarizes how the model is used in forward pass)
|
||||||
forward_code_files = list(sorted(self.compilation_config.traced_files))
|
code_hash = _compute_code_hash(self.compilation_config.traced_files)
|
||||||
self.compilation_config.traced_files.clear()
|
self.compilation_config.traced_files.clear()
|
||||||
logger.debug(
|
|
||||||
"Traced files (to be considered for compilation cache):\n%s",
|
|
||||||
"\n".join(forward_code_files),
|
|
||||||
)
|
|
||||||
hash_content = []
|
|
||||||
for filepath in forward_code_files:
|
|
||||||
hash_content.append(filepath)
|
|
||||||
if filepath == "<string>":
|
|
||||||
# This means the function was dynamically generated, with
|
|
||||||
# e.g. exec(). We can't actually check these.
|
|
||||||
continue
|
|
||||||
with open(filepath) as f:
|
|
||||||
hash_content.append(f.read())
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
code_hash = hashlib.md5(
|
|
||||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
|
||||||
).hexdigest()
|
|
||||||
factors.append(code_hash)
|
factors.append(code_hash)
|
||||||
|
|
||||||
# 3. compiler hash
|
# 3. compiler hash
|
||||||
@ -695,7 +674,9 @@ class VllmBackend:
|
|||||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||||
or not self.compilation_config.cudagraph_copy_inputs
|
or not self.compilation_config.cudagraph_copy_inputs
|
||||||
):
|
):
|
||||||
return self.split_gm
|
return VllmSerializableFunction(
|
||||||
|
graph, example_inputs, self.prefix, self.split_gm
|
||||||
|
)
|
||||||
|
|
||||||
# if we need to copy input buffers for cudagraph
|
# if we need to copy input buffers for cudagraph
|
||||||
from torch._guards import detect_fake_mode
|
from torch._guards import detect_fake_mode
|
||||||
@ -740,4 +721,6 @@ class VllmBackend:
|
|||||||
list_args[index] = static_tensor
|
list_args[index] = static_tensor
|
||||||
return self.split_gm(*list_args)
|
return self.split_gm(*list_args)
|
||||||
|
|
||||||
return copy_and_call
|
return VllmSerializableFunction(
|
||||||
|
graph, example_inputs, self.prefix, copy_and_call
|
||||||
|
)
|
||||||
|
|||||||
176
vllm/compilation/caching.py
Normal file
176
vllm/compilation/caching.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
|
import pickle
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils import _pytree as pytree
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch._dynamo.aot_compile import SerializableCallable
|
||||||
|
except ImportError:
|
||||||
|
SerializableCallable = object
|
||||||
|
|
||||||
|
assert isinstance(SerializableCallable, type)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VllmSerializableFunction(SerializableCallable):
|
||||||
|
"""
|
||||||
|
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||||
|
inputs to the compiled function and return the result.
|
||||||
|
It also implements a serialization interface to support PyTorch's precompile
|
||||||
|
with custom backend, so that we can save and load the compiled function on
|
||||||
|
disk. There's no need to wrap around the compiled function if we don't want
|
||||||
|
to serialize them in particular cases.
|
||||||
|
Right now serialization for the custom backend is done via
|
||||||
|
serializing the Dynamo fx graph plus example inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, graph_module, example_inputs, prefix, optimized_call):
|
||||||
|
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||||
|
self.graph_module = graph_module
|
||||||
|
self.example_inputs = example_inputs
|
||||||
|
self.prefix = prefix
|
||||||
|
self.optimized_call = optimized_call
|
||||||
|
self.shape_env = None
|
||||||
|
sym_input = next(
|
||||||
|
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||||
|
)
|
||||||
|
if sym_input is not None:
|
||||||
|
self.shape_env = sym_input.node.shape_env
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.optimized_call(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize_compile_artifacts(
|
||||||
|
cls, compiled_fn: "VllmSerializableFunction"
|
||||||
|
) -> bytes:
|
||||||
|
import sympy
|
||||||
|
from torch._subclasses import FakeTensorMode
|
||||||
|
from torch.fx._graph_pickler import GraphPickler, Options
|
||||||
|
|
||||||
|
state = compiled_fn.__dict__.copy()
|
||||||
|
state.pop("optimized_call")
|
||||||
|
state.pop("shape_env")
|
||||||
|
for node in state["graph_module"].graph.nodes:
|
||||||
|
node.meta.pop("source_fn_stack", None)
|
||||||
|
node.meta.pop("nn_module_stack", None)
|
||||||
|
|
||||||
|
graph_reducer_override = GraphPickler.reducer_override
|
||||||
|
|
||||||
|
def _graph_reducer_override(self, obj):
|
||||||
|
if (
|
||||||
|
inspect.isclass(obj)
|
||||||
|
and issubclass(obj, sympy.Function)
|
||||||
|
and hasattr(obj, "_torch_unpickler")
|
||||||
|
):
|
||||||
|
return obj._torch_unpickler, (obj._torch_handler_name,)
|
||||||
|
if isinstance(obj, FakeTensorMode):
|
||||||
|
return type(None), ()
|
||||||
|
return graph_reducer_override(self, obj)
|
||||||
|
|
||||||
|
# Mask off tensor inputs since they are large and not needed.
|
||||||
|
state["example_inputs"] = pytree.tree_map_only(
|
||||||
|
torch.Tensor, lambda _: None, state["example_inputs"]
|
||||||
|
)
|
||||||
|
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
|
||||||
|
state["graph_module"] = GraphPickler.dumps(
|
||||||
|
state["graph_module"], Options(ops_filter=None)
|
||||||
|
)
|
||||||
|
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||||
|
return pickle.dumps(state)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
|
||||||
|
from torch._guards import TracingContext, tracing
|
||||||
|
from torch._subclasses import FakeTensorMode
|
||||||
|
from torch.fx._graph_pickler import GraphPickler
|
||||||
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||||
|
|
||||||
|
from vllm.compilation.backends import VllmBackend
|
||||||
|
|
||||||
|
state = pickle.loads(data)
|
||||||
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||||
|
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||||
|
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||||
|
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])
|
||||||
|
|
||||||
|
def optimized_call(*example_inputs):
|
||||||
|
"""
|
||||||
|
On the first run of the optimized call, we rerun the compiler
|
||||||
|
backend which should result in a cache hit. After the backend
|
||||||
|
call returns, we just do a one-time replacement of the optimized
|
||||||
|
call with the compiled function, so that subsequent calls are on
|
||||||
|
the AOT compiled path.
|
||||||
|
"""
|
||||||
|
compile_inputs = [
|
||||||
|
inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)
|
||||||
|
]
|
||||||
|
with tracing(TracingContext(fake_mode)):
|
||||||
|
fn.optimized_call = vllm_backend(
|
||||||
|
state["graph_module"], compile_inputs
|
||||||
|
).optimized_call
|
||||||
|
return fn.optimized_call(*example_inputs)
|
||||||
|
|
||||||
|
fn = cls(**state, optimized_call=optimized_call)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def co_name(self):
|
||||||
|
"""
|
||||||
|
Used for depyf debugging.
|
||||||
|
"""
|
||||||
|
return "VllmSerializableFunction"
|
||||||
|
|
||||||
|
|
||||||
|
def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||||
|
factors = []
|
||||||
|
# 0. factors come from the env, for example, The values of
|
||||||
|
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||||
|
env_hash = envs.compute_hash()
|
||||||
|
factors.append(env_hash)
|
||||||
|
|
||||||
|
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||||
|
# model is created)
|
||||||
|
config_hash = vllm_config.compute_hash()
|
||||||
|
factors.append(config_hash)
|
||||||
|
return factors
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||||
|
items = list(sorted(file_contents.items(), key=lambda x: x[0]))
|
||||||
|
hash_content = []
|
||||||
|
for filepath, content in items:
|
||||||
|
hash_content.append(filepath)
|
||||||
|
if filepath == "<string>":
|
||||||
|
# This means the function was dynamically generated, with
|
||||||
|
# e.g. exec(). We can't actually check these.
|
||||||
|
continue
|
||||||
|
hash_content.append(content)
|
||||||
|
return hashlib.md5(
|
||||||
|
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_code_hash(files: set[str]) -> str:
|
||||||
|
logger.debug(
|
||||||
|
"Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
|
||||||
|
)
|
||||||
|
file_contents = {}
|
||||||
|
for filepath in files:
|
||||||
|
if filepath == "<string>":
|
||||||
|
file_contents[filepath] = ""
|
||||||
|
else:
|
||||||
|
with open(filepath) as f:
|
||||||
|
file_contents[filepath] = f.read()
|
||||||
|
return _compute_code_hash_with_content(file_contents)
|
||||||
@ -199,6 +199,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
if compiler_config is not None:
|
if compiler_config is not None:
|
||||||
current_config.update(compiler_config)
|
current_config.update(compiler_config)
|
||||||
set_inductor_config(current_config, runtime_shape)
|
set_inductor_config(current_config, runtime_shape)
|
||||||
|
set_functorch_config()
|
||||||
|
|
||||||
if isinstance(runtime_shape, int):
|
if isinstance(runtime_shape, int):
|
||||||
dynamic_shapes = "from_example_inputs"
|
dynamic_shapes = "from_example_inputs"
|
||||||
@ -307,6 +308,7 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
current_config["fx_graph_remote_cache"] = False
|
current_config["fx_graph_remote_cache"] = False
|
||||||
|
|
||||||
set_inductor_config(current_config, runtime_shape)
|
set_inductor_config(current_config, runtime_shape)
|
||||||
|
set_functorch_config()
|
||||||
|
|
||||||
# inductor can inplace modify the graph, so we need to copy it
|
# inductor can inplace modify the graph, so we need to copy it
|
||||||
# see https://github.com/pytorch/pytorch/issues/138980
|
# see https://github.com/pytorch/pytorch/issues/138980
|
||||||
@ -596,6 +598,10 @@ def set_inductor_config(config, runtime_shape):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_functorch_config():
|
||||||
|
torch._functorch.config.bundled_autograd_cache = False
|
||||||
|
|
||||||
|
|
||||||
class EagerAdaptor(CompilerInterface):
|
class EagerAdaptor(CompilerInterface):
|
||||||
name = "eager"
|
name = "eager"
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
from typing import Callable, Optional, TypeVar, Union, overload
|
from typing import Callable, Optional, TypeVar, Union, overload
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -11,9 +14,10 @@ import torch.nn as nn
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
|
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
|
||||||
@ -176,6 +180,33 @@ def support_torch_compile(
|
|||||||
return cls_decorator_helper
|
return cls_decorator_helper
|
||||||
|
|
||||||
|
|
||||||
|
def _model_hash_key(fn) -> str:
|
||||||
|
import vllm
|
||||||
|
|
||||||
|
sha256_hash = hashlib.sha256()
|
||||||
|
sha256_hash.update(vllm.__version__.encode())
|
||||||
|
sha256_hash.update(fn.__qualname__.encode())
|
||||||
|
sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
|
||||||
|
return sha256_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_source_unchanged(source_info, vllm_config) -> None:
|
||||||
|
from .caching import _compute_code_hash, _compute_code_hash_with_content
|
||||||
|
|
||||||
|
file_contents = {}
|
||||||
|
for source in source_info.inlined_sources:
|
||||||
|
module = sys.modules[source.module]
|
||||||
|
file = inspect.getfile(module)
|
||||||
|
vllm_config.compilation_config.traced_files.add(file)
|
||||||
|
file_contents[file] = source.content
|
||||||
|
expected_checksum = _compute_code_hash_with_content(file_contents)
|
||||||
|
actual_checksum = _compute_code_hash(set(file_contents.keys()))
|
||||||
|
if expected_checksum != actual_checksum:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Source code has changed since the last compilation. Recompiling the model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _support_torch_compile(
|
def _support_torch_compile(
|
||||||
cls: _T,
|
cls: _T,
|
||||||
dynamic_arg_dims: dict[str, Union[int, list[int]]],
|
dynamic_arg_dims: dict[str, Union[int, list[int]]],
|
||||||
@ -227,6 +258,64 @@ def _support_torch_compile(
|
|||||||
if self.do_not_compile or torch.compiler.is_compiling():
|
if self.do_not_compile or torch.compiler.is_compiling():
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||||
|
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||||
|
|
||||||
|
cache_dir = None
|
||||||
|
aot_compilation_path = None
|
||||||
|
if envs.VLLM_USE_AOT_COMPILE:
|
||||||
|
"""
|
||||||
|
When using torch.compile in AOT mode, we store the cache artifacts
|
||||||
|
under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
|
||||||
|
contains all of the factors except for the source files being
|
||||||
|
traced through, because we don't actually know which source files
|
||||||
|
to check at this point (before dynamo runs).
|
||||||
|
On loading we will actually look at the source files being traced
|
||||||
|
through. If any source file have changed (compared with the
|
||||||
|
serialized backend artifacts), then we need to generate a new AOT
|
||||||
|
compile artifact from scratch.
|
||||||
|
"""
|
||||||
|
from .caching import compilation_config_hash_factors
|
||||||
|
|
||||||
|
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
|
||||||
|
|
||||||
|
factors.append(_model_hash_key(self.forward))
|
||||||
|
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||||
|
|
||||||
|
cache_dir = os.path.join(
|
||||||
|
envs.VLLM_CACHE_ROOT,
|
||||||
|
"torch_aot_compile",
|
||||||
|
hash_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
rank = self.vllm_config.parallel_config.rank
|
||||||
|
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||||
|
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||||
|
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||||
|
try:
|
||||||
|
with (
|
||||||
|
set_current_vllm_config(self.vllm_config),
|
||||||
|
open(aot_compilation_path, "rb") as f,
|
||||||
|
):
|
||||||
|
start_monitoring_torch_compile(self.vllm_config)
|
||||||
|
loaded_fn = torch.compiler.load_compiled_function(f)
|
||||||
|
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||||
|
self.aot_compiled_fn = loaded_fn
|
||||||
|
except Exception as e:
|
||||||
|
if os.path.exists(aot_compilation_path):
|
||||||
|
logger.warning(
|
||||||
|
"Cannot load aot compilation from path %s, error: %s",
|
||||||
|
aot_compilation_path,
|
||||||
|
str(e),
|
||||||
|
)
|
||||||
|
if envs.VLLM_FORCE_AOT_LOAD:
|
||||||
|
raise e
|
||||||
|
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||||
|
logger.info(
|
||||||
|
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||||
|
)
|
||||||
|
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||||
|
|
||||||
# the first compilation needs to have dynamic shapes marked
|
# the first compilation needs to have dynamic shapes marked
|
||||||
if len(self.compiled_codes) < 1:
|
if len(self.compiled_codes) < 1:
|
||||||
sig = inspect.signature(self.__class__.forward)
|
sig = inspect.signature(self.__class__.forward)
|
||||||
@ -275,15 +364,15 @@ def _support_torch_compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. every time Dynamo sees a function call, it will inline
|
# 2. every time Dynamo sees a function call, it will inline
|
||||||
# the function by calling InliningInstructionTranslator.inline_call
|
# the function by calling InliningInstructionTranslator.inline_call_
|
||||||
# we hijack this function to know all the functions called
|
# we hijack this function to know all the functions called
|
||||||
# during Dynamo tracing, and their corresponding files
|
# during Dynamo tracing, and their corresponding files
|
||||||
inline_call = InliningInstructionTranslator.inline_call
|
inline_call = InliningInstructionTranslator.inline_call_
|
||||||
|
|
||||||
def patched_inline_call(parent, func, args, kwargs):
|
def patched_inline_call(self_):
|
||||||
code = func.get_code()
|
code = self_.f_code
|
||||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||||
return inline_call(parent, func, args, kwargs)
|
return inline_call(self_)
|
||||||
|
|
||||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||||
# of symbolic shape guards can improve guard overhead. But, since
|
# of symbolic shape guards can improve guard overhead. But, since
|
||||||
@ -300,13 +389,21 @@ def _support_torch_compile(
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(
|
patch.object(
|
||||||
InliningInstructionTranslator, "inline_call", patched_inline_call
|
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||||
),
|
),
|
||||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||||
_torch27_patch_tensor_subclasses(),
|
_torch27_patch_tensor_subclasses(),
|
||||||
):
|
):
|
||||||
output = self.compiled_callable(*args, **kwargs)
|
if envs.VLLM_USE_AOT_COMPILE:
|
||||||
|
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||||
|
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||||
|
assert aot_compilation_path is not None
|
||||||
|
assert cache_dir is not None
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
|
||||||
|
else:
|
||||||
|
output = self.compiled_callable(*args, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# usually, capturing the model once is enough, and then we can
|
# usually, capturing the model once is enough, and then we can
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
|
from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@ -44,6 +45,19 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
options = (
|
options = (
|
||||||
get_current_vllm_config().compilation_config.inductor_compile_config
|
get_current_vllm_config().compilation_config.inductor_compile_config
|
||||||
)
|
)
|
||||||
|
if envs.VLLM_USE_AOT_COMPILE:
|
||||||
|
options = options or {}
|
||||||
|
# This effectively drop all the guards.
|
||||||
|
# We need this because bytecode hook is not used any more to
|
||||||
|
# drop guards in the AOT compile mode.
|
||||||
|
options["guard_filter_fn"] = lambda guards: [False for _ in guards]
|
||||||
|
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||||
|
torch._dynamo.config.enable_aot_compile = True
|
||||||
|
else:
|
||||||
|
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||||
|
msg += "available. AOT compile is disabled and please "
|
||||||
|
msg += "upgrade PyTorch version to use AOT compile."
|
||||||
|
logger.warning(msg)
|
||||||
|
|
||||||
compiled_callable = torch.compile(
|
compiled_callable = torch.compile(
|
||||||
self.forward, fullgraph=True, backend=backend, options=options
|
self.forward, fullgraph=True, backend=backend, options=options
|
||||||
@ -61,6 +75,15 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
compilation_level >= CompilationLevel.DYNAMO_ONCE
|
compilation_level >= CompilationLevel.DYNAMO_ONCE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def aot_compile(self, *args, **kwargs):
|
||||||
|
if not hasattr(self.compiled_callable, "aot_compile"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"aot_compile is not supported by the current configuration. "
|
||||||
|
+ "Please make sure torch.compile is enabled with the latest "
|
||||||
|
+ f"version of PyTorch (current using torch: {torch.__version__})"
|
||||||
|
)
|
||||||
|
return self.compiled_callable.aot_compile((args, kwargs))
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Implement the dispatch logic here, beyond the torch.compile level.
|
"""Implement the dispatch logic here, beyond the torch.compile level.
|
||||||
NOTE: this function can have additional arguments beyond the forward
|
NOTE: this function can have additional arguments beyond the forward
|
||||||
|
|||||||
17
vllm/envs.py
17
vllm/envs.py
@ -89,6 +89,8 @@ if TYPE_CHECKING:
|
|||||||
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
|
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
|
||||||
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
|
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
|
||||||
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
|
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
|
||||||
|
VLLM_USE_AOT_COMPILE: bool = False
|
||||||
|
VLLM_FORCE_AOT_LOAD: bool = False
|
||||||
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
|
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
|
||||||
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
|
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
|
||||||
VLLM_USE_TRITON_AWQ: bool = False
|
VLLM_USE_TRITON_AWQ: bool = False
|
||||||
@ -235,6 +237,13 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
|
|||||||
return bool(int(value))
|
return bool(int(value))
|
||||||
|
|
||||||
|
|
||||||
|
def use_aot_compile() -> bool:
|
||||||
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
|
||||||
|
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
|
||||||
|
|
||||||
|
|
||||||
def env_with_choices(
|
def env_with_choices(
|
||||||
env_name: str,
|
env_name: str,
|
||||||
default: Optional[str],
|
default: Optional[str],
|
||||||
@ -494,6 +503,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# Dump fx graphs to the given directory.
|
# Dump fx graphs to the given directory.
|
||||||
# It will override CompilationConfig.debug_dump_path if set.
|
# It will override CompilationConfig.debug_dump_path if set.
|
||||||
"VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
|
"VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
|
||||||
|
# Feature flag to enable/disable AOT compilation. This will ensure
|
||||||
|
# compilation is done in warmup phase and the compilation will be
|
||||||
|
# reused in subsequent calls.
|
||||||
|
"VLLM_USE_AOT_COMPILE": use_aot_compile,
|
||||||
|
# Force vllm to always load AOT compiled models from disk. Failure
|
||||||
|
# to load will result in a hard error when this is enabled.
|
||||||
|
# Will be ignored when VLLM_USE_AOT_COMPILE is disabled.
|
||||||
|
"VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1",
|
||||||
# local rank of the process in the distributed setting, used to determine
|
# local rank of the process in the distributed setting, used to determine
|
||||||
# the GPU device id
|
# the GPU device id
|
||||||
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
|
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user