mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 15:11:25 +08:00
Fix JAX jit OOM
This commit is contained in:
parent
092e3d6d6d
commit
d5fb1c20c1
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@ -36,12 +36,7 @@ class TPUModelRunner:
|
|||||||
"The model will run without sliding window.")
|
"The model will run without sliding window.")
|
||||||
self.model = None
|
self.model = None
|
||||||
self.block_size = None
|
self.block_size = None
|
||||||
# FIXME
|
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
|
||||||
# self.compiled_fn = jax.jit(
|
|
||||||
# self._execute_step,
|
|
||||||
# donate_argnums=6,
|
|
||||||
# )
|
|
||||||
self.compiled_fn = self._execute_step
|
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
@ -53,7 +48,8 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
model_name = "google/gemma-7b-flax"
|
model_name = "google/gemma-7b-flax"
|
||||||
model_dir = snapshot_download(model_name)
|
model_dir = snapshot_download(model_name)
|
||||||
self.params = load_and_format_params(model_dir + "/7b/")
|
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||||
|
self.params = {"params": params}
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_prompt(
|
||||||
self,
|
self,
|
||||||
@ -170,6 +166,7 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
def _execute_step(
|
def _execute_step(
|
||||||
self,
|
self,
|
||||||
|
params: Dict[str, Any],
|
||||||
token_ids: jax.Array,
|
token_ids: jax.Array,
|
||||||
position_ids: jax.Array,
|
position_ids: jax.Array,
|
||||||
slot_mapping: jax.Array,
|
slot_mapping: jax.Array,
|
||||||
@ -177,13 +174,13 @@ class TPUModelRunner:
|
|||||||
context_lens: Optional[jax.Array],
|
context_lens: Optional[jax.Array],
|
||||||
input_lens: jax.Array,
|
input_lens: jax.Array,
|
||||||
kv_caches: List[jax.Array],
|
kv_caches: List[jax.Array],
|
||||||
) -> tuple[jax.Array, list[jax.Array]]:
|
) -> tuple[jax.Array, List[jax.Array]]:
|
||||||
batch_size, seq_len = token_ids.shape
|
batch_size, seq_len = token_ids.shape
|
||||||
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
|
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
|
||||||
logits_indices = base_indicies + input_lens - 1
|
logits_indices = base_indicies + input_lens - 1
|
||||||
|
|
||||||
logits, new_kv_caches = self.model.apply(
|
logits, new_kv_caches = self.model.apply(
|
||||||
{"params": self.params["transformer"]},
|
params,
|
||||||
token_ids,
|
token_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
slot_mapping,
|
slot_mapping,
|
||||||
@ -202,7 +199,7 @@ class TPUModelRunner:
|
|||||||
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
||||||
|
|
||||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||||
logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches)
|
logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
|
||||||
next_token_ids = jnp.argmax(logits, axis=-1)
|
next_token_ids = jnp.argmax(logits, axis=-1)
|
||||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user