diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 46f447b26adb9..f67f781f7c109 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,5 +1,7 @@ +import os from typing import Dict, List, Optional, Tuple +import jax import jax.numpy as jnp import torch @@ -55,10 +57,17 @@ class TPUWorker(LoraNotSupportedWorkerBase): vision_language_config=vision_language_config) self.tpu_cache = None + # jax.config.update("jax_compilation_cache_dir", + # os.path.expanduser("~/.vllm/jax_cache")) + def init_device(self) -> None: # Set random seed. + # TODO: Set random seed for JAX set_random_seed(self.model_config.seed) - # TODO: JAX + + # DELETE + from jax_smi import initialise_tracking + initialise_tracking() def load_model(self): self.model_runner.load_model()