From 3f6288cc89aa4d4743a82f673f594476eb6f0ca9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 08:56:12 +0000 Subject: [PATCH] Fix for binary cache --- vllm/worker/tpu_model_runner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 73376070b4b3f..72a23a7ba89a0 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -16,6 +16,7 @@ logger = init_logger(__name__) _PAD_SLOT_ID = -1 _MAX_NUM_SEQS = 256 +_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16 class TPUModelRunner: @@ -41,7 +42,8 @@ class TPUModelRunner: self.block_size = None self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, )) # FIXME(woosuk) - self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32) + self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ), + dtype=np.int32) def load_model(self) -> None: from huggingface_hub import snapshot_download @@ -55,6 +57,7 @@ class TPUModelRunner: model_dir = snapshot_download(model_name) params = load_and_format_params(model_dir + "/7b/")["transformer"] self.params = {"params": params} + self.cpu_device = jax.devices("cpu")[0] def warmup_model( self, @@ -91,8 +94,8 @@ class TPUModelRunner: token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) - block_tables = jnp.asarray(self.block_tables[:batch_size], - dtype=jnp.int32) + block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ), + dtype=jnp.int32) context_lens = jnp.ones((batch_size, ), dtype=jnp.int32) prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32) @@ -263,6 +266,7 @@ class TPUModelRunner: start = time.time() inputs = self.prepare_input_arrays(seq_group_metadata_list) end = time.time() + # print(inputs[0].shape) # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms") start = time.time() @@ -273,7 +277,7 @@ class TPUModelRunner: # print(f"compiled_fn: {(end - start) * 1000:.2f} ms") start = time.time() - next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) + next_token_ids = jax.device_put(next_token_ids, self.cpu_device) end = time.time() # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")