Use persistent cache

This commit is contained in:
Woosuk Kwon 2024-04-26 07:09:44 +00:00
parent 707a5f6473
commit f6637dba18

View File

@ -57,14 +57,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
vision_language_config=vision_language_config) vision_language_config=vision_language_config)
self.tpu_cache = None self.tpu_cache = None
# jax.config.update("jax_compilation_cache_dir",
# os.path.expanduser("~/.vllm/jax_cache"))
def init_device(self) -> None: def init_device(self) -> None:
# Set random seed. # Set random seed.
# TODO: Set random seed for JAX # TODO: Set random seed for JAX
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Use persistent cache to avoid recompilation.
jax.config.update("jax_compilation_cache_dir",
os.path.expanduser("~/.vllm/jax_cache"))
# DELETE # DELETE
from jax_smi import initialise_tracking from jax_smi import initialise_tracking
initialise_tracking() initialise_tracking()