mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 17:24:28 +08:00
[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
ce688ad46e
commit
d70bc7c029
@ -7,6 +7,7 @@ import os
|
|||||||
import pprint
|
import pprint
|
||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -66,7 +67,25 @@ class CompilerManager:
|
|||||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||||
return self.compiler.compute_hash(vllm_config)
|
return self.compiler.compute_hash(vllm_config)
|
||||||
|
|
||||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
def initialize_cache(self,
|
||||||
|
cache_dir: str,
|
||||||
|
disable_cache: bool = False,
|
||||||
|
prefix: str = ""):
|
||||||
|
"""
|
||||||
|
Initialize the cache directory for the compiler.
|
||||||
|
|
||||||
|
The organization of the cache directory is as follows:
|
||||||
|
cache_dir=/path/to/hash_str/rank_i_j/prefix/
|
||||||
|
inside cache_dir, there will be:
|
||||||
|
- vllm_compile_cache.py
|
||||||
|
- computation_graph.py
|
||||||
|
- transformed_code.py
|
||||||
|
|
||||||
|
for multiple prefixes, they can share the same
|
||||||
|
base cache dir of /path/to/hash_str/rank_i_j/ ,
|
||||||
|
to store some common compilation artifacts.
|
||||||
|
"""
|
||||||
|
|
||||||
self.disable_cache = disable_cache
|
self.disable_cache = disable_cache
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
||||||
@ -80,7 +99,8 @@ class CompilerManager:
|
|||||||
self.cache = ast.literal_eval(f.read())
|
self.cache = ast.literal_eval(f.read())
|
||||||
|
|
||||||
self.compiler.initialize_cache(cache_dir=cache_dir,
|
self.compiler.initialize_cache(cache_dir=cache_dir,
|
||||||
disable_cache=disable_cache)
|
disable_cache=disable_cache,
|
||||||
|
prefix=prefix)
|
||||||
|
|
||||||
def save_to_file(self):
|
def save_to_file(self):
|
||||||
if self.disable_cache or not self.is_cache_updated:
|
if self.disable_cache or not self.is_cache_updated:
|
||||||
@ -310,6 +330,25 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# the tag for the part of model being compiled,
|
||||||
|
# e.g. backbone/eagle_head
|
||||||
|
model_tag: str = "backbone"
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_model_tag(tag: str):
|
||||||
|
"""Context manager to set the model tag."""
|
||||||
|
global model_tag
|
||||||
|
assert tag != model_tag, \
|
||||||
|
f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||||
|
old_tag = model_tag
|
||||||
|
model_tag = tag
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
model_tag = old_tag
|
||||||
|
|
||||||
|
|
||||||
class VllmBackend:
|
class VllmBackend:
|
||||||
"""The compilation backend for `torch.compile` with vLLM.
|
"""The compilation backend for `torch.compile` with vLLM.
|
||||||
It is used for compilation level of `CompilationLevel.PIECEWISE`,
|
It is used for compilation level of `CompilationLevel.PIECEWISE`,
|
||||||
@ -341,7 +380,17 @@ class VllmBackend:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
|
|
||||||
|
# if the model is initialized with a non-empty prefix,
|
||||||
|
# then usually it's enough to use that prefix,
|
||||||
|
# e.g. launguage_model, vision_model, etc.
|
||||||
|
# when multiple parts are initialized as independent
|
||||||
|
# models, we need to use the model_tag to distinguish
|
||||||
|
# them, e.g. backbone (default), eagle_head, etc.
|
||||||
|
self.prefix = prefix or model_tag
|
||||||
|
|
||||||
global global_graph_pool
|
global global_graph_pool
|
||||||
if global_graph_pool is None:
|
if global_graph_pool is None:
|
||||||
global_graph_pool = current_platform.graph_pool_handle()
|
global_graph_pool = current_platform.graph_pool_handle()
|
||||||
@ -441,16 +490,13 @@ class VllmBackend:
|
|||||||
)
|
)
|
||||||
self.compilation_config.cache_dir = cache_dir
|
self.compilation_config.cache_dir = cache_dir
|
||||||
|
|
||||||
if compilation_counter.num_graphs_seen > 0:
|
cache_dir = self.compilation_config.cache_dir
|
||||||
cache_dir = self.compilation_config.cache_dir + \
|
|
||||||
f'-{compilation_counter.num_graphs_seen}'
|
|
||||||
else:
|
|
||||||
cache_dir = self.compilation_config.cache_dir
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
self.compilation_config.cache_dir = cache_dir
|
self.compilation_config.cache_dir = cache_dir
|
||||||
rank = vllm_config.parallel_config.rank
|
rank = vllm_config.parallel_config.rank
|
||||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
|
||||||
|
self.prefix)
|
||||||
os.makedirs(local_cache_dir, exist_ok=True)
|
os.makedirs(local_cache_dir, exist_ok=True)
|
||||||
self.compilation_config.local_cache_dir = local_cache_dir
|
self.compilation_config.local_cache_dir = local_cache_dir
|
||||||
|
|
||||||
@ -462,7 +508,8 @@ class VllmBackend:
|
|||||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||||
local_cache_dir)
|
local_cache_dir)
|
||||||
|
|
||||||
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
|
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
|
||||||
|
self.prefix)
|
||||||
|
|
||||||
# when dynamo calls the backend, it means the bytecode
|
# when dynamo calls the backend, it means the bytecode
|
||||||
# transform and analysis are done
|
# transform and analysis are done
|
||||||
|
|||||||
@ -28,11 +28,22 @@ class CompilerInterface:
|
|||||||
# This is a class-level attribute.
|
# This is a class-level attribute.
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
def initialize_cache(self,
|
||||||
|
cache_dir: str,
|
||||||
|
disable_cache: bool = False,
|
||||||
|
prefix: str = ""):
|
||||||
"""
|
"""
|
||||||
when the vLLM process uses `cache_dir` as the cache directory,
|
when the vLLM process uses `cache_dir` as the cache directory,
|
||||||
the compiler should initialize itself with the cache directory,
|
the compiler should initialize itself with the cache directory,
|
||||||
e.g. by re-directing its own cache directory to a sub-directory.
|
e.g. by re-directing its own cache directory to a sub-directory.
|
||||||
|
|
||||||
|
prefix can be used in combination with cache_dir to figure out the base
|
||||||
|
cache directory, e.g. there're multiple parts of model being compiled,
|
||||||
|
but we want to share the same cache directory for all of them.
|
||||||
|
|
||||||
|
e.g.
|
||||||
|
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||||
|
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
usedforsecurity=False).hexdigest()[:10]
|
usedforsecurity=False).hexdigest()[:10]
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
def initialize_cache(self,
|
||||||
|
cache_dir: str,
|
||||||
|
disable_cache: bool = False,
|
||||||
|
prefix: str = ""):
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
|
|
||||||
def compile(
|
def compile(
|
||||||
@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
usedforsecurity=False).hexdigest()[:10]
|
usedforsecurity=False).hexdigest()[:10]
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
def initialize_cache(self,
|
||||||
|
cache_dir: str,
|
||||||
|
disable_cache: bool = False,
|
||||||
|
prefix: str = ""):
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
|
self.prefix = prefix
|
||||||
|
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
|
||||||
if disable_cache:
|
if disable_cache:
|
||||||
return
|
return
|
||||||
# redirect the cache directory to a sub-directory
|
# redirect the cache directory to a sub-directory
|
||||||
# set flags so that Inductor and Triton store their cache
|
# set flags so that Inductor and Triton store their cache
|
||||||
# in the cache_dir, then users only need to copy the cache_dir
|
# in the cache_dir, then users only need to copy the cache_dir
|
||||||
# to another machine to reuse the cache.
|
# to another machine to reuse the cache.
|
||||||
inductor_cache = os.path.join(cache_dir, "inductor_cache")
|
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||||
os.makedirs(inductor_cache, exist_ok=True)
|
os.makedirs(inductor_cache, exist_ok=True)
|
||||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||||
triton_cache = os.path.join(cache_dir, "triton_cache")
|
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||||
os.makedirs(triton_cache, exist_ok=True)
|
os.makedirs(triton_cache, exist_ok=True)
|
||||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||||
|
|
||||||
@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
nonlocal file_path
|
nonlocal file_path
|
||||||
compiled_fn = inductor_compiled_graph.current_callable
|
compiled_fn = inductor_compiled_graph.current_callable
|
||||||
file_path = compiled_fn.__code__.co_filename # noqa
|
file_path = compiled_fn.__code__.co_filename # noqa
|
||||||
if not file_path.startswith(self.cache_dir):
|
if not file_path.startswith(self.base_cache_dir):
|
||||||
# hooked in the align_inputs_from_check_idxs function
|
# hooked in the align_inputs_from_check_idxs function
|
||||||
# in torch/_inductor/utils.py
|
# in torch/_inductor/utils.py
|
||||||
for cell in compiled_fn.__closure__:
|
for cell in compiled_fn.__closure__:
|
||||||
if not callable(cell.cell_contents):
|
if not callable(cell.cell_contents):
|
||||||
continue
|
continue
|
||||||
if cell.cell_contents.__code__.co_filename.startswith(
|
if cell.cell_contents.__code__.co_filename.startswith(
|
||||||
self.cache_dir):
|
self.base_cache_dir):
|
||||||
# this is the real file path compiled from Inductor
|
# this is the real file path compiled from Inductor
|
||||||
file_path = cell.cell_contents.__code__.co_filename
|
file_path = cell.cell_contents.__code__.co_filename
|
||||||
break
|
break
|
||||||
@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
nonlocal file_path
|
nonlocal file_path
|
||||||
compiled_fn = inductor_compiled_graph.current_callable
|
compiled_fn = inductor_compiled_graph.current_callable
|
||||||
file_path = compiled_fn.__code__.co_filename # noqa
|
file_path = compiled_fn.__code__.co_filename # noqa
|
||||||
if not file_path.startswith(self.cache_dir):
|
if not file_path.startswith(self.base_cache_dir):
|
||||||
# hooked in the align_inputs_from_check_idxs function
|
# hooked in the align_inputs_from_check_idxs function
|
||||||
# in torch/_inductor/utils.py
|
# in torch/_inductor/utils.py
|
||||||
for cell in compiled_fn.__closure__:
|
for cell in compiled_fn.__closure__:
|
||||||
if not callable(cell.cell_contents):
|
if not callable(cell.cell_contents):
|
||||||
continue
|
continue
|
||||||
code = cell.cell_contents.__code__
|
code = cell.cell_contents.__code__
|
||||||
if code.co_filename.startswith(self.cache_dir):
|
if code.co_filename.startswith(
|
||||||
|
self.base_cache_dir):
|
||||||
# this is the real file path
|
# this is the real file path
|
||||||
# compiled from Inductor
|
# compiled from Inductor
|
||||||
file_path = code.co_filename
|
file_path = code.co_filename
|
||||||
|
|||||||
@ -4666,10 +4666,13 @@ class VllmConfig:
|
|||||||
|
|
||||||
|
|
||||||
_current_vllm_config: Optional[VllmConfig] = None
|
_current_vllm_config: Optional[VllmConfig] = None
|
||||||
|
_current_prefix: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
def set_current_vllm_config(vllm_config: VllmConfig,
|
||||||
|
check_compile=False,
|
||||||
|
prefix: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Temporarily set the current vLLM config.
|
Temporarily set the current vLLM config.
|
||||||
Used during model initialization.
|
Used during model initialization.
|
||||||
@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
|||||||
so that all modules can access it, e.g. custom ops
|
so that all modules can access it, e.g. custom ops
|
||||||
can access the vLLM config to determine how to dispatch.
|
can access the vLLM config to determine how to dispatch.
|
||||||
"""
|
"""
|
||||||
global _current_vllm_config
|
global _current_vllm_config, _current_prefix
|
||||||
old_vllm_config = _current_vllm_config
|
old_vllm_config = _current_vllm_config
|
||||||
|
old_prefix = _current_prefix
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
num_models_seen = compilation_counter.num_models_seen
|
num_models_seen = compilation_counter.num_models_seen
|
||||||
try:
|
try:
|
||||||
_current_vllm_config = vllm_config
|
_current_vllm_config = vllm_config
|
||||||
|
_current_prefix = prefix
|
||||||
yield
|
yield
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
|||||||
vllm_config.model_config.model)
|
vllm_config.model_config.model)
|
||||||
finally:
|
finally:
|
||||||
_current_vllm_config = old_vllm_config
|
_current_vllm_config = old_vllm_config
|
||||||
|
_current_prefix = old_prefix
|
||||||
|
|
||||||
|
|
||||||
def get_current_vllm_config() -> VllmConfig:
|
def get_current_vllm_config() -> VllmConfig:
|
||||||
@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
|
|||||||
return _current_vllm_config
|
return _current_vllm_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_model_prefix() -> str:
|
||||||
|
"""
|
||||||
|
Get the prefix of the model that's currently being initialized.
|
||||||
|
"""
|
||||||
|
assert _current_prefix is not None, \
|
||||||
|
"Current model prefix is not set. "
|
||||||
|
return _current_prefix
|
||||||
|
|
||||||
|
|
||||||
def contains_object_print(text):
|
def contains_object_print(text):
|
||||||
"""
|
"""
|
||||||
Check if the text looks like a printed Python object, e.g.
|
Check if the text looks like a printed Python object, e.g.
|
||||||
|
|||||||
@ -58,7 +58,9 @@ def initialize_model(
|
|||||||
all_params = [param.name for param in signatures.parameters.values()]
|
all_params = [param.name for param in signatures.parameters.values()]
|
||||||
if "vllm_config" in all_params and "prefix" in all_params:
|
if "vllm_config" in all_params and "prefix" in all_params:
|
||||||
# new-style model class
|
# new-style model class
|
||||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
with set_current_vllm_config(vllm_config,
|
||||||
|
check_compile=True,
|
||||||
|
prefix=prefix):
|
||||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||||
@ -86,7 +88,9 @@ def initialize_model(
|
|||||||
kwargs["lora_config"] = vllm_config.lora_config
|
kwargs["lora_config"] = vllm_config.lora_config
|
||||||
if "scheduler_config" in all_params:
|
if "scheduler_config" in all_params:
|
||||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
with set_current_vllm_config(vllm_config,
|
||||||
|
check_compile=True,
|
||||||
|
prefix=prefix):
|
||||||
return model_class(**kwargs)
|
return model_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -320,8 +320,10 @@ class EagleProposer:
|
|||||||
target_attn_layer_names = set(
|
target_attn_layer_names = set(
|
||||||
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
|
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
|
||||||
|
|
||||||
self.model = get_model(vllm_config=self.vllm_config,
|
from vllm.compilation.backends import set_model_tag
|
||||||
model_config=draft_model_config)
|
with set_model_tag("eagle_head"):
|
||||||
|
self.model = get_model(vllm_config=self.vllm_config,
|
||||||
|
model_config=draft_model_config)
|
||||||
|
|
||||||
draft_attn_layer_names = (
|
draft_attn_layer_names = (
|
||||||
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
|
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
|
||||||
|
|||||||
@ -48,9 +48,11 @@ class MedusaProposer:
|
|||||||
return [list(row) for row in zip(*draft_tokens)]
|
return [list(row) for row in zip(*draft_tokens)]
|
||||||
|
|
||||||
def load_model(self, target_model: nn.Module) -> None:
|
def load_model(self, target_model: nn.Module) -> None:
|
||||||
self.model = get_model(vllm_config=self.vllm_config,
|
from vllm.compilation.backends import set_model_tag
|
||||||
model_config=self.vllm_config.
|
with set_model_tag("medusa_head"):
|
||||||
speculative_config.draft_model_config)
|
self.model = get_model(vllm_config=self.vllm_config,
|
||||||
|
model_config=self.vllm_config.
|
||||||
|
speculative_config.draft_model_config)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def dummy_run(self, num_tokens: int) -> None:
|
def dummy_run(self, num_tokens: int) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user