[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-06-13 15:23:25 +08:00 committed by GitHub
parent ce688ad46e
commit d70bc7c029
6 changed files with 117 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import os
import pprint
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional
import torch
@ -66,7 +67,25 @@ class CompilerManager:
def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
"""
Initialize the cache directory for the compiler.
The organization of the cache directory is as follows:
cache_dir=/path/to/hash_str/rank_i_j/prefix/
inside cache_dir, there will be:
- vllm_compile_cache.py
- computation_graph.py
- transformed_code.py
Multiple prefixes can share the same base cache dir
/path/to/hash_str/rank_i_j/ to store some common
compilation artifacts.
"""
self.disable_cache = disable_cache
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
@ -80,7 +99,8 @@ class CompilerManager:
self.cache = ast.literal_eval(f.read())
self.compiler.initialize_cache(cache_dir=cache_dir,
disable_cache=disable_cache)
disable_cache=disable_cache,
prefix=prefix)
def save_to_file(self):
if self.disable_cache or not self.is_cache_updated:
@ -310,6 +330,25 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
return output
# the tag for the part of the model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"
@contextmanager
def set_model_tag(tag: str):
"""Context manager to set the model tag."""
global model_tag
assert tag != model_tag, \
f"Model tag {tag} is the same as the current tag {model_tag}."
old_tag = model_tag
model_tag = tag
try:
yield
finally:
model_tag = old_tag
class VllmBackend:
"""The compilation backend for `torch.compile` with vLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`,
@ -341,7 +380,17 @@ class VllmBackend:
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
@ -441,16 +490,13 @@ class VllmBackend:
)
self.compilation_config.cache_dir = cache_dir
if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
self.prefix)
os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir
@ -462,7 +508,8 @@ class VllmBackend:
logger.info("Using cache directory: %s for vLLM's torch.compile",
local_cache_dir)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
self.prefix)
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
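To illustrate the new layout, here is a minimal runnable sketch (paths, hash value, and ranks are hypothetical) of how the per-model cache directory is derived, mirroring the os.path.join above:
import os
# Hypothetical values for illustration only.
cache_dir = "/path/to/hash_str"              # per-config hash directory
rank, dp_rank = 0, 0                         # parallel_config.rank / data_parallel_rank
for prefix in ("backbone", "eagle_head"):    # default tag and a spec-decode head
    local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", prefix)
    print(local_cache_dir)
# /path/to/hash_str/rank_0_0/backbone
# /path/to/hash_str/rank_0_0/eagle_head
# Each per-prefix directory holds vllm_compile_cache.py, computation_graph.py
# and transformed_code.py, while the shared parent rank_0_0/ holds artifacts
# common to all prefixes (e.g. the Inductor and Triton caches).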

View File

@ -28,11 +28,22 @@ class CompilerInterface:
# This is a class-level attribute.
name: str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.
prefix can be used in combination with cache_dir to figure out the base
cache directory, e.g. when there are multiple parts of the model being
compiled but we want to share the same cache directory for all of them.
e.g.
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
"""
pass
@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
usedforsecurity=False).hexdigest()[:10]
return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
self.cache_dir = cache_dir
def compile(
@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
usedforsecurity=False).hexdigest()[:10]
return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
if disable_cache:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir):
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
if cell.cell_contents.__code__.co_filename.startswith(
self.cache_dir):
self.base_cache_dir):
# this is the real file path compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename
break
@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir):
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.cache_dir):
if code.co_filename.startswith(
self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename

View File

@ -4666,10 +4666,13 @@ class VllmConfig:
_current_vllm_config: Optional[VllmConfig] = None
_current_prefix: Optional[str] = None
@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
def set_current_vllm_config(vllm_config: VllmConfig,
check_compile=False,
prefix: Optional[str] = None):
"""
Temporarily set the current vLLM config.
Used during model initialization.
@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
so that all modules can access it, e.g. custom ops
can access the vLLM config to determine how to dispatch.
"""
global _current_vllm_config
global _current_vllm_config, _current_prefix
old_vllm_config = _current_vllm_config
old_prefix = _current_prefix
from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
_current_prefix = prefix
yield
except Exception:
raise
@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
vllm_config.model_config.model)
finally:
_current_vllm_config = old_vllm_config
_current_prefix = old_prefix
def get_current_vllm_config() -> VllmConfig:
@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
return _current_vllm_config
def get_current_model_prefix() -> str:
"""
Get the prefix of the model that's currently being initialized.
"""
assert _current_prefix is not None, \
"Current model prefix is not set. "
return _current_prefix
def contains_object_print(text):
"""
Check if the text looks like a printed Python object, e.g.
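A hedged usage sketch of the new prefix plumbing; build_model_part is a hypothetical helper, and the import path assumes the symbols live in vllm.config as in the diff above:
from vllm.config import (VllmConfig, get_current_model_prefix,
                         set_current_vllm_config)
def build_model_part(vllm_config: VllmConfig, prefix: str):
    # While the context manager is active, any code involved in model
    # initialization (e.g. custom ops or the compilation backend) can look
    # up which part of the model is currently being built.
    with set_current_vllm_config(vllm_config, check_compile=True,
                                 prefix=prefix):
        assert get_current_model_prefix() == prefix
        ...  # construct the submodule here (placeholder)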

View File

@ -58,7 +58,9 @@ def initialize_model(
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True):
with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@ -86,7 +88,9 @@ def initialize_model(
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True):
with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(**kwargs)

View File

@ -320,8 +320,10 @@ class EagleProposer:
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -

View File

@ -48,9 +48,11 @@ class MedusaProposer:
return [list(row) for row in zip(*draft_tokens)]
def load_model(self, target_model: nn.Module) -> None:
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
from vllm.compilation.backends import set_model_tag
with set_model_tag("medusa_head"):
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
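For completeness, a minimal sketch of the pattern both proposers now follow; load_draft_model is a hypothetical wrapper and the get_model import path is assumed:
from vllm.compilation.backends import set_model_tag
from vllm.model_executor.model_loader import get_model  # import path assumed
def load_draft_model(vllm_config, draft_model_config):
    # Everything compiled while this tag is active is cached under
    # .../rank_i_j/medusa_head/ instead of the default backbone/ prefix.
    with set_model_tag("medusa_head"):
        return get_model(vllm_config=vllm_config,
                         model_config=draft_model_config)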