Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-13 19:17:40 -07:00
parent caf963f2e9
commit 5c133fc860

View File

@ -129,7 +129,7 @@ class BlockTables:
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
non_blocking=True)
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
@ -148,7 +148,7 @@ class BlockTables:
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
_compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)](
idx_mapping,
self.buffer_ptrs,
self.block_table_ptrs,
@ -167,7 +167,7 @@ class BlockTables:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
@ -200,8 +200,8 @@ def _append_block_ids_kernel(
# Constants
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
group_id = tl.program_id(1)
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
@ -246,9 +246,9 @@ def _compute_block_tables_kernel(
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
@ -280,12 +280,12 @@ def _compute_slot_mappings_kernel(
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(0) - 1:
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)