From d5fb1c20c1ca1e6165e170e85fd76992f4a69742 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Apr 2024 07:52:56 +0000 Subject: [PATCH] Fix JAX jit OOM --- vllm/worker/tpu_model_runner.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 3278c5dada3b3..2f92d6be7c0f6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import jax import jax.numpy as jnp @@ -36,12 +36,7 @@ class TPUModelRunner: "The model will run without sliding window.") self.model = None self.block_size = None - # FIXME - # self.compiled_fn = jax.jit( - # self._execute_step, - # donate_argnums=6, - # ) - self.compiled_fn = self._execute_step + self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,)) def load_model(self) -> None: from huggingface_hub import snapshot_download @@ -53,7 +48,8 @@ class TPUModelRunner: model_name = "google/gemma-7b-flax" model_dir = snapshot_download(model_name) - self.params = load_and_format_params(model_dir + "/7b/") + params = load_and_format_params(model_dir + "/7b/")["transformer"] + self.params = {"params": params} def _prepare_prompt( self, @@ -170,6 +166,7 @@ class TPUModelRunner: def _execute_step( self, + params: Dict[str, Any], token_ids: jax.Array, position_ids: jax.Array, slot_mapping: jax.Array, @@ -177,13 +174,13 @@ class TPUModelRunner: context_lens: Optional[jax.Array], input_lens: jax.Array, kv_caches: List[jax.Array], - ) -> tuple[jax.Array, list[jax.Array]]: + ) -> tuple[jax.Array, List[jax.Array]]: batch_size, seq_len = token_ids.shape base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len logits_indices = base_indicies + input_lens - 1 logits, new_kv_caches = self.model.apply( - {"params": self.params["transformer"]}, + params, token_ids, position_ids, slot_mapping, @@ -202,7 +199,7 @@ class TPUModelRunner: from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob inputs = self.prepare_input_arrays(seq_group_metadata_list) - logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches) + logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches) next_token_ids = jnp.argmax(logits, axis=-1) next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) next_token_ids = next_token_ids.tolist()