This commit is contained in:
Woosuk Kwon 2024-04-25 03:28:53 +00:00
parent 5323969fcf
commit e2c7dedb3a

View File

@@ -193,7 +193,7 @@ class TPUModelRunner:
kv_caches, kv_caches,
logits_indices, logits_indices,
) )
# TODO # TODO(woosuk): Support sampling with temperature and top_p.
next_token_ids = jnp.argmax(logits, axis=-1) next_token_ids = jnp.argmax(logits, axis=-1)
return next_token_ids, new_kv_caches return next_token_ids, new_kv_caches