mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:27:13 +08:00
Fix JAX jit OOM
This commit is contained in:
parent
092e3d6d6d
commit
d5fb1c20c1
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -36,12 +36,7 @@ class TPUModelRunner:
|
||||
"The model will run without sliding window.")
|
||||
self.model = None
|
||||
self.block_size = None
|
||||
# FIXME
|
||||
# self.compiled_fn = jax.jit(
|
||||
# self._execute_step,
|
||||
# donate_argnums=6,
|
||||
# )
|
||||
self.compiled_fn = self._execute_step
|
||||
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
|
||||
|
||||
def load_model(self) -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
@ -53,7 +48,8 @@ class TPUModelRunner:
|
||||
|
||||
model_name = "google/gemma-7b-flax"
|
||||
model_dir = snapshot_download(model_name)
|
||||
self.params = load_and_format_params(model_dir + "/7b/")
|
||||
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||
self.params = {"params": params}
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
@ -170,6 +166,7 @@ class TPUModelRunner:
|
||||
|
||||
def _execute_step(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
token_ids: jax.Array,
|
||||
position_ids: jax.Array,
|
||||
slot_mapping: jax.Array,
|
||||
@ -177,13 +174,13 @@ class TPUModelRunner:
|
||||
context_lens: Optional[jax.Array],
|
||||
input_lens: jax.Array,
|
||||
kv_caches: List[jax.Array],
|
||||
) -> tuple[jax.Array, list[jax.Array]]:
|
||||
) -> tuple[jax.Array, List[jax.Array]]:
|
||||
batch_size, seq_len = token_ids.shape
|
||||
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
|
||||
logits_indices = base_indicies + input_lens - 1
|
||||
|
||||
logits, new_kv_caches = self.model.apply(
|
||||
{"params": self.params["transformer"]},
|
||||
params,
|
||||
token_ids,
|
||||
position_ids,
|
||||
slot_mapping,
|
||||
@ -202,7 +199,7 @@ class TPUModelRunner:
|
||||
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
||||
|
||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||
logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches)
|
||||
logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
|
||||
next_token_ids = jnp.argmax(logits, axis=-1)
|
||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user