diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 71b4b38fb9d62..bbcc4d59ae1ca 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -163,8 +163,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size_bytes = (dtype_btyes * self.cache_config.block_size *
+        dtype_bytes = get_dtype_size(self.cache_dtype)
+        block_size_bytes = (dtype_bytes * self.cache_config.block_size *
                             num_layers * 2 * head_size * num_kv_heads)
         num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
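
For context, here is a minimal standalone sketch of the block-sizing arithmetic the hunk touches. Only the formula (`dtype_bytes * block_size * num_layers * 2 * head_size * num_kv_heads`, followed by integer division and rounding down to a multiple of 8) comes from the diff; all concrete numbers below are hypothetical, chosen only to make the computation runnable.

```python
# Hypothetical illustration of the KV-cache block sizing shown in the hunk above.
# The values here are made-up example inputs, not defaults from vLLM.

dtype_bytes = 2                    # e.g. a 2-byte cache dtype such as bfloat16
block_size = 16                    # tokens per KV-cache block
num_layers = 32
head_size = 128
num_kv_heads = 8
tpu_kv_cache_bytes = 8 * 1024**3   # assume 8 GiB left over for the KV cache

# Bytes needed for one block; the factor of 2 covers the key and value tensors.
block_size_bytes = (dtype_bytes * block_size * num_layers * 2 * head_size *
                    num_kv_heads)

num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to a multiple of 8.

print(block_size_bytes, num_tpu_blocks)    # 2097152 4096 with these inputs
```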