mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-14 23:37:57 +08:00
Use persistent cache
This commit is contained in:
parent
707a5f6473
commit
f6637dba18
@ -57,14 +57,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
vision_language_config=vision_language_config)
|
vision_language_config=vision_language_config)
|
||||||
self.tpu_cache = None
|
self.tpu_cache = None
|
||||||
|
|
||||||
# jax.config.update("jax_compilation_cache_dir",
|
|
||||||
# os.path.expanduser("~/.vllm/jax_cache"))
|
|
||||||
|
|
||||||
def init_device(self) -> None:
|
def init_device(self) -> None:
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
# TODO: Set random seed for JAX
|
# TODO: Set random seed for JAX
|
||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
|
# Use persistent cache to avoid recompilation.
|
||||||
|
jax.config.update("jax_compilation_cache_dir",
|
||||||
|
os.path.expanduser("~/.vllm/jax_cache"))
|
||||||
|
|
||||||
# DELETE
|
# DELETE
|
||||||
from jax_smi import initialise_tracking
|
from jax_smi import initialise_tracking
|
||||||
initialise_tracking()
|
initialise_tracking()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user