Move JAX-smi to worker

This commit is contained in:
Woosuk Kwon 2024-04-26 07:05:51 +00:00
parent 57690a9c09
commit 707a5f6473

View File

@ -1,5 +1,7 @@
import os
from typing import Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
import torch
@ -55,10 +57,17 @@ class TPUWorker(LoraNotSupportedWorkerBase):
vision_language_config=vision_language_config)
self.tpu_cache = None
# jax.config.update("jax_compilation_cache_dir",
# os.path.expanduser("~/.vllm/jax_cache"))
def init_device(self) -> None:
# Set random seed.
# TODO: Set random seed for JAX
set_random_seed(self.model_config.seed)
# TODO: JAX
# DELETE
from jax_smi import initialise_tracking
initialise_tracking()
def load_model(self):
self.model_runner.load_model()