This commit is contained in:
Woosuk Kwon 2024-04-19 08:08:25 +00:00
parent 743695f586
commit 84284302d8

View File

@ -12,6 +12,8 @@ from vllm.utils import pad_to_max_length
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class TPUModelRunner:
@ -35,7 +37,10 @@ class TPUModelRunner:
self.model = None
self.block_size = None
# FIXME
# self.compiled_fn = jax.jit(self._execute_step)
# self.compiled_fn = jax.jit(
# self._execute_step,
# donate_argnums=6,
# )
self.compiled_fn = self._execute_step
def load_model(self) -> None:
@ -96,7 +101,7 @@ class TPUModelRunner:
dtype=jnp.int32)
slot_mapping = _make_array_with_pad(slot_mapping,
max_prompt_len,
pad=0, # FIXME
pad=_PAD_SLOT_ID,
dtype=jnp.int32)
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens