Include argmax in the jitted function

This commit is contained in:
Woosuk Kwon 2024-04-24 08:56:45 +00:00
parent 620e7646d3
commit f42b4c27d8

View File

@ -10,6 +10,10 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length
# TODO: delete the commented-out jax_smi tracking lines below before merging.
# from jax_smi import initialise_tracking
# initialise_tracking()
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
@ -189,7 +193,9 @@ class TPUModelRunner:
kv_caches,
logits_indices,
)
return logits, new_kv_caches
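# TODO: argmax always picks the highest-logit token; presumably a placeholder until real sampling is implemented — confirm.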
next_token_ids = jnp.argmax(logits, axis=-1)
return next_token_ids, new_kv_caches
def execute_model(
self,
@ -199,8 +205,7 @@ class TPUModelRunner:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
inputs = self.prepare_input_arrays(seq_group_metadata_list)
logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
next_token_ids = jnp.argmax(logits, axis=-1)
next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
next_token_ids = next_token_ids.tolist()
i = 0