From f5c8628fdc78bf9ca70206ef41175030fb67e870 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Jun 2024 13:42:40 -0700
Subject: [PATCH] [Bugfix][TPU] Fix CPU cache allocation (#5869)

---
 vllm/attention/backends/pallas.py | 5 ++---
 vllm/worker/tpu_worker.py         | 8 ++++++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 121ca9ec45205..5dec11e2eede7 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -37,11 +37,10 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> None:
         src_k_cache, src_v_cache = src_kv_cache
         dst_k_cache, dst_v_cache = dst_kv_cache
+        src_indices, dst_indices = src_to_dst
+        device = dst_k_cache.device
         torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
         torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-
-        device = dst_k_cache.device
-        src_indices, dst_indices = src_to_dst
         dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
         dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
 
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index c85bf6892fb28..28f460c31aa9b 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -156,14 +156,18 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self.tpu_cache = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
+        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
+            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
         for _ in range(num_layers):
             tpu_k_cache = torch.zeros(tpu_cache_shape,
                                       dtype=dtype,
                                       device=self.device)
             tpu_v_cache = torch.zeros_like(tpu_k_cache)
             self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
-            cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
-            cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
+            cpu_k_cache = torch.zeros(cpu_cache_shape,
+                                      dtype=dtype,
+                                      device="cpu")
+            cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
 
         self._warmup_model()