[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent ce688ad46e, commit d70bc7c029
@@ -7,6 +7,7 @@ import os
 import pprint
 import time
 from collections.abc import Sequence
+from contextlib import contextmanager
 from typing import Any, Callable, Optional
 
 import torch
@@ -66,7 +67,25 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
+        """
+        Initialize the cache directory for the compiler.
+
+        The organization of the cache directory is as follows:
+        cache_dir=/path/to/hash_str/rank_i_j/prefix/
+        inside cache_dir, there will be:
+        - vllm_compile_cache.py
+        - computation_graph.py
+        - transformed_code.py
+
+        for multiple prefixes, they can share the same
+        base cache dir of /path/to/hash_str/rank_i_j/ ,
+        to store some common compilation artifacts.
+        """
+
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
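The layout described in the docstring above can be sketched concretely; the hash string, rank numbers, and prefix below are hypothetical example values, not taken from this diff.

# Sketch of the cache layout documented above (hypothetical values).
import os

hash_str = "a1b2c3d4e5"          # hash of the vLLM config (example only)
rank, dp_rank = 0, 0             # rank i and data-parallel rank j
prefix = "backbone"              # the part of the model being compiled

base_dir = os.path.join("/path/to", hash_str, f"rank_{rank}_{dp_rank}")
cache_dir = os.path.join(base_dir, prefix)
# cache_dir == "/path/to/a1b2c3d4e5/rank_0_0/backbone" and will hold
# vllm_compile_cache.py, computation_graph.py and transformed_code.py;
# artifacts shared across prefixes live directly under base_dir.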
@@ -80,7 +99,8 @@ class CompilerManager:
                 self.cache = ast.literal_eval(f.read())
 
         self.compiler.initialize_cache(cache_dir=cache_dir,
-                                       disable_cache=disable_cache)
+                                       disable_cache=disable_cache,
+                                       prefix=prefix)
 
     def save_to_file(self):
         if self.disable_cache or not self.is_cache_updated:
@@ -310,6 +330,25 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         return output
 
 
+# the tag for the part of model being compiled,
+# e.g. backbone/eagle_head
+model_tag: str = "backbone"
+
+
+@contextmanager
+def set_model_tag(tag: str):
+    """Context manager to set the model tag."""
+    global model_tag
+    assert tag != model_tag, \
+        f"Model tag {tag} is the same as the current tag {model_tag}."
+    old_tag = model_tag
+    model_tag = tag
+    try:
+        yield
+    finally:
+        model_tag = old_tag
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
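A minimal usage sketch of the context manager added above, assuming only what this diff shows (the spec-decode proposers further down use it the same way):

# Compile a second part of the model under its own tag, so it gets its own
# cache sub-directory; the default tag is "backbone".
from vllm.compilation.backends import set_model_tag

with set_model_tag("eagle_head"):
    # a VllmBackend created here without an explicit prefix will use
    # "eagle_head" as its cache prefix
    ...
# on exit the tag reverts to the previous value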
@@ -341,7 +380,17 @@ class VllmBackend:
     def __init__(
         self,
         vllm_config: VllmConfig,
         prefix: str = "",
     ):
 
+        # if the model is initialized with a non-empty prefix,
+        # then usually it's enough to use that prefix,
+        # e.g. language_model, vision_model, etc.
+        # when multiple parts are initialized as independent
+        # models, we need to use the model_tag to distinguish
+        # them, e.g. backbone (default), eagle_head, etc.
+        self.prefix = prefix or model_tag
+
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = current_platform.graph_pool_handle()
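The comment block above boils down to `prefix or model_tag`; a tiny illustrative sketch follows (resolve_prefix is a hypothetical helper, not part of the diff):

# Hypothetical helper showing the prefix-selection rule: an explicit
# submodule prefix wins, otherwise the ambient model_tag is used.
model_tag = "backbone"

def resolve_prefix(prefix: str = "") -> str:
    return prefix or model_tag

assert resolve_prefix("language_model") == "language_model"
assert resolve_prefix() == "backbone"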
@@ -441,16 +490,13 @@ class VllmBackend:
             )
             self.compilation_config.cache_dir = cache_dir
 
-        if compilation_counter.num_graphs_seen > 0:
-            cache_dir = self.compilation_config.cache_dir + \
-                f'-{compilation_counter.num_graphs_seen}'
-        else:
-            cache_dir = self.compilation_config.cache_dir
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
+                                       self.prefix)
         os.makedirs(local_cache_dir, exist_ok=True)
         self.compilation_config.local_cache_dir = local_cache_dir
 
@@ -462,7 +508,8 @@ class VllmBackend:
         logger.info("Using cache directory: %s for vLLM's torch.compile",
                     local_cache_dir)
 
-        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
+        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
+                                               self.prefix)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
@@ -28,11 +28,22 @@ class CompilerInterface:
     # This is a class-level attribute.
     name: str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
         e.g. by re-directing its own cache directory to a sub-directory.
+
+        prefix can be used in combination with cache_dir to figure out the base
+        cache directory, e.g. there're multiple parts of model being compiled,
+        but we want to share the same cache directory for all of them.
+
+        e.g.
+        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
+        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
         """
         pass
 
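Given the cache_dir/prefix pairs in the docstring's example, an implementation can recover the shared base directory by stripping the prefix from the end of cache_dir, which is what InductorAdaptor does further down; a small sketch (base_dir_of is a hypothetical helper, the paths are the docstring's example values):

# Hypothetical helper mirroring the slicing used by InductorAdaptor below.
def base_dir_of(cache_dir: str, prefix: str) -> str:
    return cache_dir[:-len(prefix)] if prefix else cache_dir

assert base_dir_of("/path/to/dir/backbone", "backbone") == "/path/to/dir/"
assert base_dir_of("/path/to/dir/eagle_head", "eagle_head") == "/path/to/dir/"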
@@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
 
     def compile(
@@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
+        self.prefix = prefix
+        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
         # set flags so that Inductor and Triton store their cache
         # in the cache_dir, then users only need to copy the cache_dir
         # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
         os.makedirs(inductor_cache, exist_ok=True)
         os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
+        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
         os.makedirs(triton_cache, exist_ok=True)
         os.environ["TRITON_CACHE_DIR"] = triton_cache
 
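Because the Inductor and Triton caches are now rooted at the base directory rather than the per-prefix cache_dir, every compiled part of the model shares them; a short sketch of the resulting environment (the path is a hypothetical example):

# Both prefixes end up pointing Inductor/Triton at the same shared dirs.
import os

base_cache_dir = "/path/to/hash_str/rank_0_0"    # hypothetical base dir
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(base_cache_dir,
                                                     "inductor_cache")
os.environ["TRITON_CACHE_DIR"] = os.path.join(base_cache_dir, "triton_cache")
# copying the whole base dir to another machine carries the caches for
# every compiled part of the model.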
@@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
             nonlocal file_path
             compiled_fn = inductor_compiled_graph.current_callable
             file_path = compiled_fn.__code__.co_filename  # noqa
-            if not file_path.startswith(self.cache_dir):
+            if not file_path.startswith(self.base_cache_dir):
                 # hooked in the align_inputs_from_check_idxs function
                 # in torch/_inductor/utils.py
                 for cell in compiled_fn.__closure__:
                     if not callable(cell.cell_contents):
                         continue
                     if cell.cell_contents.__code__.co_filename.startswith(
-                            self.cache_dir):
+                            self.base_cache_dir):
                         # this is the real file path compiled from Inductor
                         file_path = cell.cell_contents.__code__.co_filename
                         break
@@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
             nonlocal file_path
             compiled_fn = inductor_compiled_graph.current_callable
             file_path = compiled_fn.__code__.co_filename  # noqa
-            if not file_path.startswith(self.cache_dir):
+            if not file_path.startswith(self.base_cache_dir):
                 # hooked in the align_inputs_from_check_idxs function
                 # in torch/_inductor/utils.py
                 for cell in compiled_fn.__closure__:
                     if not callable(cell.cell_contents):
                         continue
                     code = cell.cell_contents.__code__
-                    if code.co_filename.startswith(self.cache_dir):
+                    if code.co_filename.startswith(
+                            self.base_cache_dir):
                         # this is the real file path
                         # compiled from Inductor
                         file_path = code.co_filename
 
@@ -4666,10 +4666,13 @@ class VllmConfig:
 
 
 _current_vllm_config: Optional[VllmConfig] = None
+_current_prefix: Optional[str] = None
 
 
 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
+def set_current_vllm_config(vllm_config: VllmConfig,
+                            check_compile=False,
+                            prefix: Optional[str] = None):
     """
     Temporarily set the current vLLM config.
     Used during model initialization.
@@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
     so that all modules can access it, e.g. custom ops
     can access the vLLM config to determine how to dispatch.
     """
-    global _current_vllm_config
+    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
+    old_prefix = _current_prefix
     from vllm.compilation.counter import compilation_counter
     num_models_seen = compilation_counter.num_models_seen
     try:
         _current_vllm_config = vllm_config
+        _current_prefix = prefix
         yield
     except Exception:
         raise
@@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
                 vllm_config.model_config.model)
     finally:
         _current_vllm_config = old_vllm_config
+        _current_prefix = old_prefix
 
 
 def get_current_vllm_config() -> VllmConfig:
@@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
     return _current_vllm_config
 
 
+def get_current_model_prefix() -> str:
+    """
+    Get the prefix of the model that's currently being initialized.
+    """
+    assert _current_prefix is not None, \
+        "Current model prefix is not set. "
+    return _current_prefix
+
+
 def contains_object_print(text):
     """
     Check if the text looks like a printed Python object, e.g.
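A hedged sketch of how the new prefix plumbing fits together during model construction: the prefix is set via set_current_vllm_config and read back with get_current_model_prefix (the prefix string and build_submodule function are hypothetical):

# Hypothetical usage: code running inside the context can ask which part of
# the model is currently being initialized.
from vllm.config import (VllmConfig, get_current_model_prefix,
                         set_current_vllm_config)

def build_submodule(vllm_config: VllmConfig) -> None:
    with set_current_vllm_config(vllm_config, prefix="language_model"):
        # e.g. a custom op could branch on the prefix here
        print(get_current_model_prefix())  # -> "language_model"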
@@ -58,7 +58,9 @@ def initialize_model(
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        with set_current_vllm_config(vllm_config, check_compile=True):
+        with set_current_vllm_config(vllm_config,
+                                     check_compile=True,
+                                     prefix=prefix):
             return model_class(vllm_config=vllm_config, prefix=prefix)
 
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@@ -86,7 +88,9 @@ def initialize_model(
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    with set_current_vllm_config(vllm_config, check_compile=True):
+    with set_current_vllm_config(vllm_config,
+                                 check_compile=True,
+                                 prefix=prefix):
         return model_class(**kwargs)
 
 
@@ -320,8 +320,10 @@ class EagleProposer:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=draft_model_config)
+        from vllm.compilation.backends import set_model_tag
+        with set_model_tag("eagle_head"):
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
@@ -48,9 +48,11 @@ class MedusaProposer:
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=self.vllm_config.
-                               speculative_config.draft_model_config)
+        from vllm.compilation.backends import set_model_tag
+        with set_model_tag("medusa_head"):
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=self.vllm_config.
+                                   speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None:
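Putting the pieces together, the per-part cache directories produced by this change look roughly like the sketch below (the hash and rank values are hypothetical):

# One sub-directory per compiled part, under a shared per-rank base dir.
import os

base = "/path/to/hash_str/rank_0_0"              # hypothetical base dir
for part in ("backbone", "eagle_head", "medusa_head"):
    print(os.path.join(base, part))
# shared artifacts (inductor_cache/, triton_cache/) live directly under base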