diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index f67f781f7c109..4e58830865758 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -57,14 +57,15 @@ class TPUWorker(LoraNotSupportedWorkerBase): vision_language_config=vision_language_config) self.tpu_cache = None - # jax.config.update("jax_compilation_cache_dir", - # os.path.expanduser("~/.vllm/jax_cache")) - def init_device(self) -> None: # Set random seed. # TODO: Set random seed for JAX set_random_seed(self.model_config.seed) + # Use persistent cache to avoid recompilation. + jax.config.update("jax_compilation_cache_dir", + os.path.expanduser("~/.vllm/jax_cache")) + # DELETE from jax_smi import initialise_tracking initialise_tracking()