[torch.compile] transparent compilation with more logging (#12246)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-21 19:32:55 +08:00 committed by GitHub
parent a94eee4456
commit c81081fece
4 changed files with 50 additions and 7 deletions

vllm/compilation/backends.py

@@ -524,6 +524,7 @@ class VllmBackend:
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ class VllmBackend:
             # 1. factors come from the vllm_config (it mainly summarizes how the
             #    model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
             # 2. factors come from the code files that are traced by Dynamo (
@@ -556,20 +556,26 @@ class VllmBackend:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        cache_dir)
+                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
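The cache-path scheme above is easy to reproduce in isolation. A minimal sketch, using placeholder hash inputs (in vLLM they come from `vllm_config.compute_hash()` and the Dynamo-traced source files) and assuming the default `VLLM_CACHE_ROOT` of `~/.cache/vllm`:

```python
import hashlib
import os

# assumption: placeholder inputs standing in for the real hashes
config_hash = "example-config-hash"
code_hash = "example-code-hash"
cache_root = os.path.expanduser("~/.cache/vllm")  # default VLLM_CACHE_ROOT
rank = 0

# same key derivation as the diff: md5 of both hashes, truncated to 10 chars
hash_key = hashlib.md5(
    f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
cache_dir = os.path.join(cache_root, "torch_compile_cache", hash_key)
# per-rank subdirectory introduced by this change
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}")
print(local_cache_dir)
```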
@@ -609,6 +615,18 @@ class VllmBackend:
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+            logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
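The `computation_graph.py` dump relies on `GraphModule.print_readable`, which returns the module's generated Python source, including submodules. A standalone sketch on a toy traced module (the `Toy` module is illustrative, not from the commit):

```python
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):

    def forward(self, x):
        return torch.relu(x + 1)


gm = symbolic_trace(Toy())
# same post-processing as the commit: prepend imports and rename the
# anonymous "<lambda>" class so the dump is a valid, importable module
src = ("from __future__ import annotations\nimport torch\n" +
       gm.print_readable(print_output=False))
src = src.replace("<lambda>", "GraphModule")
print(src)
```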

vllm/compilation/decorators.py

@@ -198,6 +198,8 @@ def _support_torch_compile(
                     f" {dims} for argument {k} with type {type(arg)}.")
 
         # here, it is the starting point of the `torch.compile` process
         start_monitoring_torch_compile(self.vllm_config)
+        logger.debug("Start compiling function %s",
+                     self.original_code_object)
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
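The new debug line prints the original code object, whose repr identifies the function name, source file, and line being compiled. A quick illustration with a hypothetical `forward` (not tied to any model):

```python
def forward(x):
    return x * 2


# a code object's repr pinpoints the function and where it lives
print("Start compiling function %s" % (forward.__code__, ))
# -> Start compiling function <code object forward at 0x..., file "...", line 1>
```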

vllm/compilation/wrapper.py

@@ -9,6 +9,9 @@ import torch
 
 import vllm.envs as envs
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class TorchCompileWrapperWithCustomDispatcher:
class TorchCompileWrapperWithCustomDispatcher:
@@ -82,6 +85,25 @@ class TorchCompileWrapperWithCustomDispatcher:
             return
 
         self.compiled_codes.append(new_code)
+        local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
+        if isinstance(local_cache_dir, str):
+            decompiled_file = os.path.join(local_cache_dir,
+                                           "transformed_code.py")
+            if not os.path.exists(decompiled_file):
+                try:
+                    # usually the decompilation will succeed for most models,
+                    # as we guarantee a full-graph compilation in Dynamo.
+                    # but there's no 100% guarantee, since decompilation is
+                    # not a reversible process.
+                    import depyf
+                    src = depyf.decompile(new_code)
+                    with open(decompiled_file, "w") as f:
+                        f.write(src)
+                    logger.debug("Dynamo transformed code saved to %s",
+                                 decompiled_file)
+                except Exception:
+                    pass
+
 
         if self.vllm_config.compilation_config.use_cudagraph and \
                 "update" in new_code.co_names:

vllm/config.py

@@ -2785,6 +2785,7 @@ class CompilationConfig(BaseModel):
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
+    local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
     # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
     # since we know all keys are in a range [0, max_capture_size],
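Taken together, after this change each rank's cache directory holds the inductor hash cache plus the two new transparency artifacts. A hypothetical helper to locate them on disk, assuming the default cache root:

```python
import glob
import os

# default VLLM_CACHE_ROOT; adjust if the environment variable is set
cache_root = os.path.expanduser("~/.cache/vllm")
pattern = os.path.join(cache_root, "torch_compile_cache", "*", "rank_*", "*")
for path in sorted(glob.glob(pattern)):
    # expect computation_graph.py and transformed_code.py alongside
    # the files written by InductorHashCache
    print(path)
```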