[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-06-13 15:23:25 +08:00 committed by GitHub
parent ce688ad46e
commit d70bc7c029
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 117 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import os
import pprint import pprint
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import torch import torch
@ -66,7 +67,25 @@ class CompilerManager:
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config) return self.compiler.compute_hash(vllm_config)
def initialize_cache(self, cache_dir: str, disable_cache: bool = False): def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
"""
Initialize the cache directory for the compiler.
The organization of the cache directory is as follows:
cache_dir=/path/to/hash_str/rank_i_j/prefix/
inside cache_dir, there will be:
- vllm_compile_cache.py
- computation_graph.py
- transformed_code.py
for multiple prefixes, they can share the same
base cache dir of /path/to/hash_str/rank_i_j/ ,
to store some common compilation artifacts.
"""
self.disable_cache = disable_cache self.disable_cache = disable_cache
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
@ -80,7 +99,8 @@ class CompilerManager:
self.cache = ast.literal_eval(f.read()) self.cache = ast.literal_eval(f.read())
self.compiler.initialize_cache(cache_dir=cache_dir, self.compiler.initialize_cache(cache_dir=cache_dir,
disable_cache=disable_cache) disable_cache=disable_cache,
prefix=prefix)
def save_to_file(self): def save_to_file(self):
if self.disable_cache or not self.is_cache_updated: if self.disable_cache or not self.is_cache_updated:
@ -310,6 +330,25 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
return output return output
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily replace the module-level ``model_tag`` with ``tag``.

    The previous tag is restored when the ``with`` block exits, including
    when the body raises.

    Args:
        tag: The tag to install for the duration of the ``with`` block.
            It must differ from the tag that is currently active.

    Raises:
        AssertionError: If ``tag`` equals the currently active tag.
    """
    global model_tag
    if tag == model_tag:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the same as before.
        raise AssertionError(
            f"Model tag {tag} is the same as the current tag {model_tag}.")
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        # Always restore the previous tag, even if the body raised.
        model_tag = old_tag
class VllmBackend: class VllmBackend:
"""The compilation backend for `torch.compile` with vLLM. """The compilation backend for `torch.compile` with vLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`, It is used for compilation level of `CompilationLevel.PIECEWISE`,
@ -341,7 +380,17 @@ class VllmBackend:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "",
): ):
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global global_graph_pool global global_graph_pool
if global_graph_pool is None: if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle() global_graph_pool = current_platform.graph_pool_handle()
@ -441,16 +490,13 @@ class VllmBackend:
) )
self.compilation_config.cache_dir = cache_dir self.compilation_config.cache_dir = cache_dir
if compilation_counter.num_graphs_seen > 0: cache_dir = self.compilation_config.cache_dir
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
self.prefix)
os.makedirs(local_cache_dir, exist_ok=True) os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
@ -462,7 +508,8 @@ class VllmBackend:
logger.info("Using cache directory: %s for vLLM's torch.compile", logger.info("Using cache directory: %s for vLLM's torch.compile",
local_cache_dir) local_cache_dir)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
self.prefix)
# when dynamo calls the backend, it means the bytecode # when dynamo calls the backend, it means the bytecode
# transform and analysis are done # transform and analysis are done

View File

@ -28,11 +28,22 @@ class CompilerInterface:
# This is a class-level attribute. # This is a class-level attribute.
name: str name: str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False): def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
""" """
when the vLLM process uses `cache_dir` as the cache directory, when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory, the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory. e.g. by re-directing its own cache directory to a sub-directory.
prefix can be used in combination with cache_dir to figure out the base
cache directory, e.g. when multiple parts of a model are being compiled
but should all share the same base cache directory.
e.g.
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
""" """
pass pass
@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
usedforsecurity=False).hexdigest()[:10] usedforsecurity=False).hexdigest()[:10]
return hash_str return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False): def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
self.cache_dir = cache_dir self.cache_dir = cache_dir
def compile( def compile(
@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
usedforsecurity=False).hexdigest()[:10] usedforsecurity=False).hexdigest()[:10]
return hash_str return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False): def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
if disable_cache: if disable_cache:
return return
# redirect the cache directory to a sub-directory # redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache # set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir # in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache. # to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache") inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True) os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache") triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True) os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache os.environ["TRITON_CACHE_DIR"] = triton_cache
@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
nonlocal file_path nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir): if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function # hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py # in torch/_inductor/utils.py
for cell in compiled_fn.__closure__: for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents): if not callable(cell.cell_contents):
continue continue
if cell.cell_contents.__code__.co_filename.startswith( if cell.cell_contents.__code__.co_filename.startswith(
self.cache_dir): self.base_cache_dir):
# this is the real file path compiled from Inductor # this is the real file path compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename file_path = cell.cell_contents.__code__.co_filename
break break
@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
nonlocal file_path nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir): if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function # hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py # in torch/_inductor/utils.py
for cell in compiled_fn.__closure__: for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents): if not callable(cell.cell_contents):
continue continue
code = cell.cell_contents.__code__ code = cell.cell_contents.__code__
if code.co_filename.startswith(self.cache_dir): if code.co_filename.startswith(
self.base_cache_dir):
# this is the real file path # this is the real file path
# compiled from Inductor # compiled from Inductor
file_path = code.co_filename file_path = code.co_filename

View File

@ -4666,10 +4666,13 @@ class VllmConfig:
_current_vllm_config: Optional[VllmConfig] = None _current_vllm_config: Optional[VllmConfig] = None
_current_prefix: Optional[str] = None
@contextmanager @contextmanager
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): def set_current_vllm_config(vllm_config: VllmConfig,
check_compile=False,
prefix: Optional[str] = None):
""" """
Temporarily set the current vLLM config. Temporarily set the current vLLM config.
Used during model initialization. Used during model initialization.
@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
so that all modules can access it, e.g. custom ops so that all modules can access it, e.g. custom ops
can access the vLLM config to determine how to dispatch. can access the vLLM config to determine how to dispatch.
""" """
global _current_vllm_config global _current_vllm_config, _current_prefix
old_vllm_config = _current_vllm_config old_vllm_config = _current_vllm_config
old_prefix = _current_prefix
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen num_models_seen = compilation_counter.num_models_seen
try: try:
_current_vllm_config = vllm_config _current_vllm_config = vllm_config
_current_prefix = prefix
yield yield
except Exception: except Exception:
raise raise
@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
vllm_config.model_config.model) vllm_config.model_config.model)
finally: finally:
_current_vllm_config = old_vllm_config _current_vllm_config = old_vllm_config
_current_prefix = old_prefix
def get_current_vllm_config() -> VllmConfig: def get_current_vllm_config() -> VllmConfig:
@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
return _current_vllm_config return _current_vllm_config
def get_current_model_prefix() -> str:
    """
    Get the prefix of the model that's currently being initialized.

    Reads the module-level ``_current_prefix``. An empty string is a valid
    prefix and is returned as-is; only ``None`` counts as "not set".

    Returns:
        The current value of ``_current_prefix``.

    Raises:
        AssertionError: If ``_current_prefix`` is ``None``.
    """
    if _current_prefix is None:
        # Raised explicitly instead of via ``assert`` so that, under
        # ``python -O``, the function cannot return ``None`` in violation
        # of its ``-> str`` annotation. Exception type and message are
        # unchanged.
        raise AssertionError("Current model prefix is not set. ")
    return _current_prefix
def contains_object_print(text): def contains_object_print(text):
""" """
Check if the text looks like a printed Python object, e.g. Check if the text looks like a printed Python object, e.g.

View File

@ -58,7 +58,9 @@ def initialize_model(
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params: if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class # new-style model class
with set_current_vllm_config(vllm_config, check_compile=True): with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(vllm_config=vllm_config, prefix=prefix) return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as " msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@ -86,7 +88,9 @@ def initialize_model(
kwargs["lora_config"] = vllm_config.lora_config kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params: if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True): with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(**kwargs) return model_class(**kwargs)

View File

@ -320,8 +320,10 @@ class EagleProposer:
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()) get_layers_from_vllm_config(self.vllm_config, Attention).keys())
self.model = get_model(vllm_config=self.vllm_config, from vllm.compilation.backends import set_model_tag
model_config=draft_model_config) with set_model_tag("eagle_head"):
self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() - get_layers_from_vllm_config(self.vllm_config, Attention).keys() -

View File

@ -48,9 +48,11 @@ class MedusaProposer:
return [list(row) for row in zip(*draft_tokens)] return [list(row) for row in zip(*draft_tokens)]
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
self.model = get_model(vllm_config=self.vllm_config, from vllm.compilation.backends import set_model_tag
model_config=self.vllm_config. with set_model_tag("medusa_head"):
speculative_config.draft_model_config) self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None: def dummy_run(self, num_tokens: int) -> None: