Fix JAX jit OOM

This commit is contained in:
Woosuk Kwon 2024-04-24 07:52:56 +00:00
parent 092e3d6d6d
commit d5fb1c20c1

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
@ -36,12 +36,7 @@ class TPUModelRunner:
"The model will run without sliding window.")
self.model = None
self.block_size = None
# FIXME
# self.compiled_fn = jax.jit(
# self._execute_step,
# donate_argnums=6,
# )
self.compiled_fn = self._execute_step
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
def load_model(self) -> None:
from huggingface_hub import snapshot_download
@ -53,7 +48,8 @@ class TPUModelRunner:
model_name = "google/gemma-7b-flax"
model_dir = snapshot_download(model_name)
self.params = load_and_format_params(model_dir + "/7b/")
params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params}
def _prepare_prompt(
self,
@ -170,6 +166,7 @@ class TPUModelRunner:
def _execute_step(
self,
params: Dict[str, Any],
token_ids: jax.Array,
position_ids: jax.Array,
slot_mapping: jax.Array,
@ -177,13 +174,13 @@ class TPUModelRunner:
context_lens: Optional[jax.Array],
input_lens: jax.Array,
kv_caches: List[jax.Array],
) -> tuple[jax.Array, list[jax.Array]]:
) -> tuple[jax.Array, List[jax.Array]]:
batch_size, seq_len = token_ids.shape
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
logits_indices = base_indicies + input_lens - 1
logits, new_kv_caches = self.model.apply(
{"params": self.params["transformer"]},
params,
token_ids,
position_ids,
slot_mapping,
@ -202,7 +199,7 @@ class TPUModelRunner:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
inputs = self.prepare_input_arrays(seq_group_metadata_list)
logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches)
logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
next_token_ids = jnp.argmax(logits, axis=-1)
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
next_token_ids = next_token_ids.tolist()