From d70bc7c02957ff6c11bc117527f7c724bb6aeb46 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 13 Jun 2025 15:23:25 +0800
Subject: [PATCH] [torch.compile] reorganize the cache directory to support
 compiling multiple models (#19064)

Signed-off-by: youkaichao
---
 vllm/compilation/backends.py              | 65 +++++++++++++++++++----
 vllm/compilation/compiler_interface.py    | 38 +++++++++----
 vllm/config.py                            | 19 ++++++-
 vllm/model_executor/model_loader/utils.py |  8 ++-
 vllm/v1/spec_decode/eagle.py              |  6 ++-
 vllm/v1/spec_decode/medusa.py             |  8 +--
 6 files changed, 117 insertions(+), 27 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 90795c29ebab7..8bb8c3a2a2e4e 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -7,6 +7,7 @@
 import os
 import pprint
 import time
 from collections.abc import Sequence
+from contextlib import contextmanager
 from typing import Any, Callable, Optional
 import torch
@@ -66,7 +67,25 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
+        """
+        Initialize the cache directory for the compiler.
+
+        The cache directory is organized as follows:
+        cache_dir=/path/to/hash_str/rank_i_j/prefix/
+        Inside cache_dir there will be:
+        - vllm_compile_cache.py
+        - computation_graph.py
+        - transformed_code.py
+
+        Multiple prefixes can share the same base cache dir,
+        /path/to/hash_str/rank_i_j/, which stores common
+        compilation artifacts.
+        """
+
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
@@ -80,7 +99,8 @@ class CompilerManager:
                 self.cache = ast.literal_eval(f.read())
 
         self.compiler.initialize_cache(cache_dir=cache_dir,
-                                       disable_cache=disable_cache)
+                                       disable_cache=disable_cache,
+                                       prefix=prefix)
 
     def save_to_file(self):
         if self.disable_cache or not self.is_cache_updated:
@@ -310,6 +330,25 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         return output
 
 
+# the tag for the part of the model being compiled,
+# e.g. backbone/eagle_head
+model_tag: str = "backbone"
+
+
+@contextmanager
+def set_model_tag(tag: str):
+    """Context manager to set the model tag."""
+    global model_tag
+    assert tag != model_tag, \
+        f"Model tag {tag} is the same as the current tag {model_tag}."
+    old_tag = model_tag
+    model_tag = tag
+    try:
+        yield
+    finally:
+        model_tag = old_tag
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -341,7 +380,17 @@ class VllmBackend:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        prefix: str = "",
     ):
+
+        # if the model is initialized with a non-empty prefix,
+        # that prefix is usually enough,
+        # e.g. language_model, vision_model, etc.
+        # when multiple parts are initialized as independent
+        # models, we need the model_tag to distinguish
+        # them, e.g. backbone (default), eagle_head, etc.
+        self.prefix = prefix or model_tag
+
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = current_platform.graph_pool_handle()
@@ -441,16 +490,13 @@ class VllmBackend:
             )
             self.compilation_config.cache_dir = cache_dir
 
-        if compilation_counter.num_graphs_seen > 0:
-            cache_dir = self.compilation_config.cache_dir + \
-                f'-{compilation_counter.num_graphs_seen}'
-        else:
-            cache_dir = self.compilation_config.cache_dir
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
+                                       self.prefix)
         os.makedirs(local_cache_dir, exist_ok=True)
         self.compilation_config.local_cache_dir = local_cache_dir
 
@@ -462,7 +508,8 @@ class VllmBackend:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
                         local_cache_dir)
 
-        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
+        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
+                                               self.prefix)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 36c810ec2dc96..fd39a6127d00b 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -28,11 +28,22 @@ class CompilerInterface:
     # This is a class-level attribute.
     name: str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
         e.g. by re-directing its own cache directory to a sub-directory.
+
+        prefix can be combined with cache_dir to figure out the base cache
+        directory, e.g. when multiple parts of the model are compiled
+        separately but should share the same base cache directory.
+
+        e.g.
+        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
+        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
         """
         pass
 
@@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
                           usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
 
     def compile(
@@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
                           usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
+        self.prefix = prefix
+        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
         # set flags so that Inductor and Triton store their cache
         # in the cache_dir, then users only need to copy the cache_dir
         # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
         os.makedirs(inductor_cache, exist_ok=True)
         os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
+        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
         os.makedirs(triton_cache, exist_ok=True)
         os.environ["TRITON_CACHE_DIR"] = triton_cache
 
@@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
                 nonlocal file_path
                 compiled_fn = inductor_compiled_graph.current_callable
                 file_path = compiled_fn.__code__.co_filename  # noqa
-                if not file_path.startswith(self.cache_dir):
+                if not file_path.startswith(self.base_cache_dir):
                     # hooked in the align_inputs_from_check_idxs function
                     # in torch/_inductor/utils.py
                     for cell in compiled_fn.__closure__:
                         if not callable(cell.cell_contents):
                             continue
                         if cell.cell_contents.__code__.co_filename.startswith(
-                                self.cache_dir):
+                                self.base_cache_dir):
                             # this is the real file path compiled from Inductor
                             file_path = cell.cell_contents.__code__.co_filename
                             break
@@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
                 nonlocal file_path
                 compiled_fn = inductor_compiled_graph.current_callable
                 file_path = compiled_fn.__code__.co_filename  # noqa
-                if not file_path.startswith(self.cache_dir):
+                if not file_path.startswith(self.base_cache_dir):
                     # hooked in the align_inputs_from_check_idxs function
                     # in torch/_inductor/utils.py
                     for cell in compiled_fn.__closure__:
                         if not callable(cell.cell_contents):
                             continue
                         code = cell.cell_contents.__code__
-                        if code.co_filename.startswith(self.cache_dir):
+                        if code.co_filename.startswith(
+                                self.base_cache_dir):
                             # this is the real file path
                             # compiled from Inductor
                             file_path = code.co_filename
diff --git a/vllm/config.py b/vllm/config.py
index d2cfbc8392528..bd15fcc553145 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4666,10 +4666,13 @@ class VllmConfig:
 
 
 _current_vllm_config: Optional[VllmConfig] = None
+_current_prefix: Optional[str] = None
 
 
 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
+def set_current_vllm_config(vllm_config: VllmConfig,
+                            check_compile=False,
+                            prefix: Optional[str] = None):
     """
     Temporarily set the current vLLM config.
     Used during model initialization.
@@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
     so that all modules can access it, e.g. custom ops
     can access the vLLM config to determine how to dispatch.
     """
-    global _current_vllm_config
+    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
+    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
+        _current_prefix = prefix
        yield
    except Exception:
        raise
@@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
                         vllm_config.model_config.model)
     finally:
         _current_vllm_config = old_vllm_config
+        _current_prefix = old_prefix
 
 
 def get_current_vllm_config() -> VllmConfig:
@@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
     return _current_vllm_config
 
 
+def get_current_model_prefix() -> str:
+    """
+    Get the prefix of the model that's currently being initialized.
+    """
+    assert _current_prefix is not None, \
+        "Current model prefix is not set. "
" + return _current_prefix + + def contains_object_print(text): """ Check if the text looks like a printed Python object, e.g. diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index e6eaade090275..79e6fa7b16dc7 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -58,7 +58,9 @@ def initialize_model( all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - with set_current_vllm_config(vllm_config, check_compile=True): + with set_current_vllm_config(vllm_config, + check_compile=True, + prefix=prefix): return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " @@ -86,7 +88,9 @@ def initialize_model( kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config, check_compile=True): + with set_current_vllm_config(vllm_config, + check_compile=True, + prefix=prefix): return model_class(**kwargs) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f7179385ebb74..7b550739a83da 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -320,8 +320,10 @@ class EagleProposer: target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index f516bf486b8b5..309fd926aecd7 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -48,9 +48,11 @@ class MedusaProposer: return [list(row) for row in zip(*draft_tokens)] def load_model(self, target_model: nn.Module) -> None: - self.model = get_model(vllm_config=self.vllm_config, - model_config=self.vllm_config. - speculative_config.draft_model_config) + from vllm.compilation.backends import set_model_tag + with set_model_tag("medusa_head"): + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) @torch.inference_mode() def dummy_run(self, num_tokens: int) -> None: