From f42b4c27d885d20afd47a09352f44ec5bd72fa34 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Apr 2024 08:56:45 +0000 Subject: [PATCH] Include argmax to jit --- vllm/worker/tpu_model_runner.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2f92d6be7c0f6..f10d512791a54 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -10,6 +10,10 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import pad_to_max_length +# DELETE +# from jax_smi import initialise_tracking +# initialise_tracking() + logger = init_logger(__name__) _PAD_SLOT_ID = -1 @@ -189,7 +193,9 @@ class TPUModelRunner: kv_caches, logits_indices, ) - return logits, new_kv_caches + # TODO + next_token_ids = jnp.argmax(logits, axis=-1) + return next_token_ids, new_kv_caches def execute_model( self, @@ -199,8 +205,7 @@ class TPUModelRunner: from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob inputs = self.prepare_input_arrays(seq_group_metadata_list) - logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches) - next_token_ids = jnp.argmax(logits, axis=-1) + next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches) next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) next_token_ids = next_token_ids.tolist() i = 0