Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-13 19:13:08 -07:00
parent 9314a83b56
commit caf963f2e9

View File

@ -82,12 +82,6 @@ class BlockTables:
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
self.max_num_reqs + 1,
dtype=torch.int32)
# NOTE(woosuk): Here, we assume that total number of new blocks
# is ALWAYS less than max_num_batched_tokens.
# TODO(woosuk): Rigorously verify that this assumption is correct.
self.new_block_ids = self._make_buffer(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
@ -119,14 +113,28 @@ class BlockTables:
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
self.new_block_ids.np[i, :len(new_block_ids[i])] = new_block_ids[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound on the number of new blocks.
new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(b) for b in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
non_blocking=True)
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
self.new_block_ids.copy_to_gpu(),
self.new_block_ids.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.buffer_ptrs,