From b15db234ba15546f37a33508c454a66e95f0a155 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 26 Apr 2024 05:43:08 +0000
Subject: [PATCH] Add precompilation step

---
 vllm/worker/tpu_model_runner.py | 13 ++++++++-----
 vllm/worker/tpu_worker.py       |  4 +++-
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 67db69c2cdf7a..2bd64005de5cb 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -61,10 +61,12 @@ class TPUModelRunner:
         params = load_and_format_params(model_dir + "/7b/")["transformer"]
         self.params = {"params": params}
 
-    def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
-                                                  jax.Array]]) -> None:
+    def warmup_model(
+        self,
+        tpu_caches: List[Tuple[jax.Array, jax.Array]],
+    ) -> List[Tuple[jax.Array, jax.Array]]:
         # Prefill
-        logger.info("Warming up the model...")
+        logger.info("Compiling the model with different input shapes...")
         start = time.time()
         for batch_size in [1]:
             for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
@@ -85,7 +87,7 @@ class TPUModelRunner:
                               block_tables, context_lens, prompt_lens,
                               tpu_caches)
         end = time.time()
-        logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
+        logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")
 
         # Decode
         start = time.time()
@@ -104,7 +106,8 @@ class TPUModelRunner:
                               block_tables, context_lens, prompt_lens,
                               tpu_caches)
         end = time.time()
-        logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
+        logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
+        return tpu_caches
 
     def _prepare_prompt(
         self,
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 8c0f2ef7acd6f..46f447b26adb9 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -92,7 +92,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self._warmup_model()
 
     def _warmup_model(self) -> None:
-        self.model_runner.warmup_model(self.tpu_cache)
+        # NOTE(woosuk): Because of buffer donation, the reference to the cache
+        # should be updated after the warmup.
+        self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)
 
     def get_cache_block_size_bytes(self) -> int:
         head_size = self.model_config.get_head_size()