Fix JAX jit OOM

This commit is contained in:
Woosuk Kwon 2024-04-24 07:52:56 +00:00
parent 092e3d6d6d
commit d5fb1c20c1

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -36,12 +36,7 @@ class TPUModelRunner:
"The model will run without sliding window.") "The model will run without sliding window.")
self.model = None self.model = None
self.block_size = None self.block_size = None
# FIXME self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
# self.compiled_fn = jax.jit(
# self._execute_step,
# donate_argnums=6,
# )
self.compiled_fn = self._execute_step
def load_model(self) -> None: def load_model(self) -> None:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@ -53,7 +48,8 @@ class TPUModelRunner:
model_name = "google/gemma-7b-flax" model_name = "google/gemma-7b-flax"
model_dir = snapshot_download(model_name) model_dir = snapshot_download(model_name)
self.params = load_and_format_params(model_dir + "/7b/") params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params}
def _prepare_prompt( def _prepare_prompt(
self, self,
@ -170,6 +166,7 @@ class TPUModelRunner:
def _execute_step( def _execute_step(
self, self,
params: Dict[str, Any],
token_ids: jax.Array, token_ids: jax.Array,
position_ids: jax.Array, position_ids: jax.Array,
slot_mapping: jax.Array, slot_mapping: jax.Array,
@ -177,13 +174,13 @@ class TPUModelRunner:
context_lens: Optional[jax.Array], context_lens: Optional[jax.Array],
input_lens: jax.Array, input_lens: jax.Array,
kv_caches: List[jax.Array], kv_caches: List[jax.Array],
) -> tuple[jax.Array, list[jax.Array]]: ) -> tuple[jax.Array, List[jax.Array]]:
batch_size, seq_len = token_ids.shape batch_size, seq_len = token_ids.shape
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
logits_indices = base_indicies + input_lens - 1 logits_indices = base_indicies + input_lens - 1
logits, new_kv_caches = self.model.apply( logits, new_kv_caches = self.model.apply(
{"params": self.params["transformer"]}, params,
token_ids, token_ids,
position_ids, position_ids,
slot_mapping, slot_mapping,
@ -202,7 +199,7 @@ class TPUModelRunner:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
inputs = self.prepare_input_arrays(seq_group_metadata_list) inputs = self.prepare_input_arrays(seq_group_metadata_list)
logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches) logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
next_token_ids = jnp.argmax(logits, axis=-1) next_token_ids = jnp.argmax(logits, axis=-1)
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()