diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index d69acb4ac16bf..fb5f25639be6a 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -64,7 +64,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self.model_runner.load_model()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
-        num_tpu_blocks = 1000
+        num_tpu_blocks = 2000
         return num_tpu_blocks, 0
 
     def initialize_cache(
@@ -86,6 +86,11 @@ class TPUWorker(LoraNotSupportedWorkerBase):
                   dtype=dtype) for _ in range(num_layers)
         ]
         self.model_runner.block_size = self.block_size
+        self._warmup_model()
+
+    def _warmup_model(self) -> None:
+        # self.model_runner.warmup_model(self.tpu_cache)
+        pass
 
     def get_cache_block_size_bytes(self) -> int:
         head_size = self.model_config.get_head_size()
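For context: the new `_warmup_model` hook lands as a no-op, with the intended call left commented out in the diff. Below is a minimal sketch of what the hook might look like once `TPUModelRunner` exposes a `warmup_model(kv_caches)` entry point; that method name is taken from the commented-out call and is an assumption, not part of this change.

```python
def _warmup_model(self) -> None:
    # Assumed future wiring (not in this diff): hand the freshly
    # allocated KV caches to the runner so it can run dummy forward
    # passes over each input-shape bucket, forcing XLA to compile
    # ahead of serving rather than on the first real request.
    self.model_runner.warmup_model(self.tpu_cache)
```

The motivation for a warmup step on TPU is that XLA compiles a fresh graph for every new input shape, so exercising the bucketed shapes during cache initialization moves that compilation cost out of the request path.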