diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py index 8e6ad49c1dd2d..f3a8569c3c020 100644 --- a/vllm/v1/worker/gpu_block_table.py +++ b/vllm/v1/worker/gpu_block_table.py @@ -82,12 +82,6 @@ class BlockTables: self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32) - # NOTE(woosuk): Here, we assume that total number of new blocks - # is ALWAYS less than max_num_batched_tokens. - # TODO(woosuk): Rigorously verify that this assumption is correct. - self.new_block_ids = self._make_buffer(self.num_kv_cache_groups, - self.max_num_batched_tokens, - dtype=torch.int32) def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: return CpuGpuBuffer(*args, @@ -119,14 +113,28 @@ class BlockTables: self.overwrite.np[:num_reqs] = overwrite for i in range(self.num_kv_cache_groups): self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i] - self.new_block_ids.np[i, :len(new_block_ids[i])] = new_block_ids[i] + + # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's + # no clear upper bound on the number of new blocks. + new_block_ids_cpu = torch.empty( + self.num_kv_cache_groups, + max(len(b) for b in new_block_ids), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) + new_block_ids_np = new_block_ids_cpu.numpy() + for i in range(self.num_kv_cache_groups): + new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i] + new_block_ids_gpu = new_block_ids_cpu.to(self.device, + non_blocking=True) _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)]( self.req_indices.copy_to_gpu(num_reqs), self.cu_num_new_blocks.copy_to_gpu(), self.cu_num_new_blocks.gpu.stride(0), - self.new_block_ids.copy_to_gpu(), - self.new_block_ids.gpu.stride(0), + new_block_ids_gpu, + new_block_ids_gpu.stride(0), self.overwrite.copy_to_gpu(num_reqs), self.block_table_strides, self.buffer_ptrs,