diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py index f3a8569c3c020..05d2b6fbe666f 100644 --- a/vllm/v1/worker/gpu_block_table.py +++ b/vllm/v1/worker/gpu_block_table.py @@ -129,7 +129,7 @@ class BlockTables: new_block_ids_gpu = new_block_ids_cpu.to(self.device, non_blocking=True) - _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)]( + _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( self.req_indices.copy_to_gpu(num_reqs), self.cu_num_new_blocks.copy_to_gpu(), self.cu_num_new_blocks.gpu.stride(0), @@ -148,7 +148,7 @@ class BlockTables: idx_mapping: torch.Tensor, ) -> tuple[torch.Tensor, ...]: batch_size = idx_mapping.shape[0] - _compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)]( + _compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)]( idx_mapping, self.buffer_ptrs, self.block_table_ptrs, @@ -167,7 +167,7 @@ class BlockTables: num_reqs = query_start_loc.shape[0] - 1 num_tokens = positions.shape[0] num_groups = self.num_kv_cache_groups - _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)]( + _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( num_tokens, self.max_num_batched_tokens, query_start_loc, @@ -200,8 +200,8 @@ def _append_block_ids_kernel( # Constants BLOCK_SIZE: tl.constexpr, ): - batch_idx = tl.program_id(0) - group_id = tl.program_id(1) + group_id = tl.program_id(0) + batch_idx = tl.program_id(1) req_idx = tl.load(req_indices + batch_idx) do_overwrite = tl.load(overwrite + batch_idx) @@ -246,9 +246,9 @@ def _compute_block_tables_kernel( num_blocks_stride, BLOCK_SIZE: tl.constexpr, ): - batch_idx = tl.program_id(0) # kv cache group id - group_id = tl.program_id(1) + group_id = tl.program_id(0) + batch_idx = tl.program_id(1) req_idx = tl.load(batch_idx_to_req_idx + batch_idx) group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride @@ -280,12 +280,12 @@ def _compute_slot_mappings_kernel( PAD_ID: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - req_idx = tl.program_id(0) # kv cache group id - group_id = tl.program_id(1) + group_id = tl.program_id(0) + req_idx = tl.program_id(1) slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride - if req_idx == tl.num_programs(0) - 1: + if req_idx == tl.num_programs(1) - 1: # Pad remaining slots to -1. This is needed for CUDA graphs. for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE)