mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-20 17:17:03 +08:00
Move JAX-smi to worker
This commit is contained in:
parent
57690a9c09
commit
707a5f6473
@ -1,5 +1,7 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
|
||||
@ -55,10 +57,17 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
vision_language_config=vision_language_config)
|
||||
self.tpu_cache = None
|
||||
|
||||
# jax.config.update("jax_compilation_cache_dir",
|
||||
# os.path.expanduser("~/.vllm/jax_cache"))
|
||||
|
||||
def init_device(self) -> None:
|
||||
# Set random seed.
|
||||
# TODO: Set random seed for JAX
|
||||
set_random_seed(self.model_config.seed)
|
||||
# TODO: JAX
|
||||
|
||||
# DELETE
|
||||
from jax_smi import initialise_tracking
|
||||
initialise_tracking()
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user