From f6637dba183bf5cfbce77c3b379ebf3b155fcb13 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 07:09:44 +0000 Subject: [PATCH] Use persistent cache --- vllm/worker/tpu_worker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index f67f781f7c109..4e58830865758 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -57,14 +57,15 @@ class TPUWorker(LoraNotSupportedWorkerBase): vision_language_config=vision_language_config) self.tpu_cache = None - # jax.config.update("jax_compilation_cache_dir", - # os.path.expanduser("~/.vllm/jax_cache")) - def init_device(self) -> None: # Set random seed. # TODO: Set random seed for JAX set_random_seed(self.model_config.seed) + # Use persistent cache to avoid recompilation. + jax.config.update("jax_compilation_cache_dir", + os.path.expanduser("~/.vllm/jax_cache")) + # DELETE from jax_smi import initialise_tracking initialise_tracking()