From 707a5f6473d6f9301830b9842fb2738a987155b6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 07:05:51 +0000 Subject: [PATCH] Move JAX-smi to worker --- vllm/worker/tpu_worker.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 46f447b26adb9..f67f781f7c109 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,5 +1,7 @@ +import os from typing import Dict, List, Optional, Tuple +import jax import jax.numpy as jnp import torch @@ -55,10 +57,17 @@ class TPUWorker(LoraNotSupportedWorkerBase): vision_language_config=vision_language_config) self.tpu_cache = None + # jax.config.update("jax_compilation_cache_dir", + # os.path.expanduser("~/.vllm/jax_cache")) + def init_device(self) -> None: # Set random seed. + # TODO: Set random seed for JAX set_random_seed(self.model_config.seed) - # TODO: JAX + + # DELETE + from jax_smi import initialise_tracking + initialise_tracking() def load_model(self): self.model_runner.load_model()