From 84284302d8cf179472a258fae333ae781e5c4287 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 19 Apr 2024 08:08:25 +0000
Subject: [PATCH] Minor

---
 vllm/worker/tpu_model_runner.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index c685dac9ff875..d3da79a676e47 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -12,6 +12,8 @@ from vllm.utils import pad_to_max_length
 
 logger = init_logger(__name__)
 
+_PAD_SLOT_ID = -1
+
 
 class TPUModelRunner:
 
@@ -35,7 +37,10 @@ class TPUModelRunner:
         self.model = None
         self.block_size = None
         # FIXME
-        # self.compiled_fn = jax.jit(self._execute_step)
+        # self.compiled_fn = jax.jit(
+        #     self._execute_step,
+        #     donate_argnums=6,
+        # )
         self.compiled_fn = self._execute_step
 
     def load_model(self) -> None:
@@ -96,7 +101,7 @@ class TPUModelRunner:
                                               dtype=jnp.int32)
         slot_mapping = _make_array_with_pad(slot_mapping,
                                             max_prompt_len,
-                                            pad=0,  # FIXME
+                                            pad=_PAD_SLOT_ID,
                                             dtype=jnp.int32)
         prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
         return input_tokens, input_positions, slot_mapping, None, None, prompt_lens