diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
index c8c6f9615a464..a0d39c511ace6 100644
--- a/vllm/v1/worker/gpu_block_table.py
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -37,10 +37,6 @@ class BlockTables:
         self.block_tables: list[torch.Tensor] = []
         # [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
         self.block_table_buffers: list[torch.Tensor] = []
-        # [num_kv_cache_groups, max_num_reqs]
-        self.num_blocks: list[torch.Tensor] = []
-        # [num_kv_cache_groups, max_num_tokens]
-        self.slot_mappings: list[torch.Tensor] = []
         for i in range(self.num_kv_cache_groups):
             block_size = self.block_sizes[i]
             max_num_blocks = cdiv(self.max_model_len, block_size)
@@ -61,44 +57,40 @@ class BlockTables:
             )
             self.block_table_buffers.append(block_table_buffer)
 
-            num_blocks = torch.zeros(self.max_num_reqs,
-                                     dtype=torch.int32,
-                                     device=self.device)
-            self.num_blocks.append(num_blocks)
-
-            slot_mapping = torch.zeros(self.max_num_batched_tokens,
-                                       dtype=torch.int64,
-                                       device=self.device)
-            self.slot_mappings.append(slot_mapping)
-
         self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
         self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
         self.block_table_strides = torch.tensor(
             [b.stride(0) for b in self.block_tables],
             dtype=torch.int64,
             device=self.device)
-        self.num_blocks_ptrs = self._make_ptr_tensor(self.num_blocks)
         self.block_sizes_tensor = torch.tensor(self.block_sizes,
                                                dtype=torch.int32,
                                                device=self.device)
-        self.slot_mapping_ptrs = self._make_ptr_tensor(self.slot_mappings)
+        self.num_blocks = torch.zeros(self.num_kv_cache_groups,
+                                      self.max_num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
+        self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
+                                         self.max_num_batched_tokens,
+                                         dtype=torch.int64,
+                                         device=self.device)
 
         # Misc buffers.
-        self.req_indices = self._make_buffer(self.max_num_reqs, torch.int32)
-        self.overwrite = self._make_buffer(self.max_num_reqs, torch.bool)
-        self.cu_num_new_blocks: list[CpuGpuBuffer] = []
-        self.new_block_ids: list[CpuGpuBuffer] = []
-        for i in range(self.num_kv_cache_groups):
-            self.cu_num_new_blocks.append(
-                self._make_buffer(self.max_num_reqs + 1, torch.int32))
-            # NOTE(woosuk): Here, we assume that total number of new blocks
-            # is ALWAYS less than max_num_batched_tokens.
-            # TODO(woosuk): Rigorously verify that this assumption is correct.
-            self.new_block_ids.append(
-                self._make_buffer(self.max_num_batched_tokens, torch.int32))
+        self.req_indices = self._make_buffer(self.max_num_reqs,
+                                             dtype=torch.int32)
+        self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
+        self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
+                                                   self.max_num_reqs + 1,
+                                                   dtype=torch.int32)
+        # NOTE(woosuk): Here, we assume that total number of new blocks
+        # is ALWAYS less than max_num_batched_tokens.
+        # TODO(woosuk): Rigorously verify that this assumption is correct.
+        self.new_block_ids = self._make_buffer(self.num_kv_cache_groups,
+                                               self.max_num_batched_tokens,
+                                               dtype=torch.int32)
 
-    def _make_buffer(self, n: int, dtype: torch.dtype) -> CpuGpuBuffer:
-        return CpuGpuBuffer(n,
+    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(*args,
                             dtype=dtype,
                             pin_memory=self.pin_memory,
                             device=self.device)
@@ -126,24 +118,21 @@ class BlockTables:
         self.req_indices.np[:num_reqs] = req_indices
         self.overwrite.np[:num_reqs] = overwrite
         for i in range(self.num_kv_cache_groups):
-            self.cu_num_new_blocks[i].np[:num_reqs + 1] = cu_num_new_blocks[i]
+            self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
             n = len(new_block_ids[i])
-            self.new_block_ids[i].np[:n] = new_block_ids[i]
+            self.new_block_ids.np[i, :n] = new_block_ids[i]
 
-        cu_num_new_blocks_ptrs = self._make_ptr_tensor(
-            [x.copy_to_gpu(num_reqs + 1) for x in self.cu_num_new_blocks])
-        new_block_ids_ptrs = self._make_ptr_tensor([
-            x.copy_to_gpu(len(new_block_ids[i]))
-            for i, x in enumerate(self.new_block_ids)
-        ])
         _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
             self.req_indices.copy_to_gpu(num_reqs),
-            cu_num_new_blocks_ptrs,
-            new_block_ids_ptrs,
+            self.cu_num_new_blocks.copy_to_gpu(),
+            self.cu_num_new_blocks.gpu.stride(0),
+            self.new_block_ids.copy_to_gpu(),
+            self.new_block_ids.gpu.stride(0),
             self.overwrite.copy_to_gpu(num_reqs),
             self.block_table_strides,
             self.buffer_ptrs,
-            self.num_blocks_ptrs,
+            self.num_blocks,
+            self.num_blocks.stride(0),
             BLOCK_SIZE=1024,
         )
 
@@ -157,7 +146,8 @@ class BlockTables:
             self.buffer_ptrs,
             self.block_table_ptrs,
             self.block_table_strides,
-            self.num_blocks_ptrs,
+            self.num_blocks,
+            self.num_blocks.stride(0),
             BLOCK_SIZE=1024,
         )
         return tuple(b[:batch_size] for b in self.block_tables)
@@ -178,7 +168,8 @@ class BlockTables:
             self.block_table_ptrs,
             self.block_table_strides,
             self.block_sizes_tensor,
-            self.slot_mapping_ptrs,
+            self.slot_mappings,
+            self.slot_mappings.stride(0),
             PAD_ID=PAD_SLOT_ID,
             BLOCK_SIZE=1024,
         )
@@ -189,13 +180,16 @@
 def _append_block_ids_kernel(
     # Inputs
     req_indices,  # [num_reqs]
-    cu_num_new_block_ptrs,  # [num_kv_cache_groups, num_reqs + 1]
-    new_block_id_ptrs,  # [num_kv_cache_groups, num_new_blocks]
+    cu_num_new_blocks_ptr,  # [num_kv_cache_groups, num_reqs + 1]
+    cu_num_new_blocks_stride,
+    new_block_ids_ptr,  # [num_kv_cache_groups, num_new_blocks]
+    new_block_ids_stride,
     overwrite,  # [num_reqs]
     block_table_strides,  # [num_kv_cache_groups]
     # Outputs
     block_table_buffer_ptrs,  # [num_kv_cache_groups]
-    num_block_ptrs,  # [num_kv_cache_groups]
+    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
+    num_blocks_stride,
     # Constants
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -204,19 +198,19 @@ def _append_block_ids_kernel(
     req_idx = tl.load(req_indices + batch_idx)
     do_overwrite = tl.load(overwrite + batch_idx)
 
-    cu_num_new_blocks_ptr = _load_ptr(cu_num_new_block_ptrs + group_id,
-                                      tl.int32)
-    start_idx = tl.load(cu_num_new_blocks_ptr + batch_idx)
-    end_idx = tl.load(cu_num_new_blocks_ptr + batch_idx + 1)
+    group_new_blocks_ptr = (cu_num_new_blocks_ptr +
+                            group_id * cu_num_new_blocks_stride)
+    start_idx = tl.load(group_new_blocks_ptr + batch_idx)
+    end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
     num_new_blocks = end_idx - start_idx
 
-    num_blocks_ptr = _load_ptr(num_block_ptrs + group_id, tl.int32)
+    group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
     if do_overwrite:
         dst_start_idx = 0
     else:
-        dst_start_idx = tl.load(num_blocks_ptr + req_idx)
+        dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
     dst_end_idx = dst_start_idx + num_new_blocks
-    tl.store(num_blocks_ptr + req_idx, dst_end_idx)
+    tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
 
     # Destination
     block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
@@ -224,10 +218,11 @@ def _append_block_ids_kernel(
     block_table_stride = tl.load(block_table_strides + group_id)
     buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
 
-    new_block_ids_ptr = _load_ptr(new_block_id_ptrs + group_id, tl.int32)
+    group_new_block_ids_ptr = (new_block_ids_ptr +
+                               group_id * new_block_ids_stride)
     for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
         offset = i + tl.arange(0, BLOCK_SIZE)
-        block_ids = tl.load(new_block_ids_ptr + start_idx + offset,
+        block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
                             mask=offset < num_new_blocks)
         tl.store(buffer_row_ptr + dst_start_idx + offset,
                  block_ids,
@@ -240,7 +235,8 @@ def _compute_block_tables_kernel(
     src_block_table_ptrs,  # [num_kv_cache_groups]
     dst_block_table_ptrs,  # [num_kv_cache_groups]
     block_table_strides,  # [num_kv_cache_groups]
-    num_blocks_ptrs,  # [num_kv_cache_groups]
+    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
+    num_blocks_stride,
     BLOCK_SIZE: tl.constexpr,
 ):
     batch_idx = tl.program_id(0)
@@ -248,8 +244,8 @@ def _compute_block_tables_kernel(
     group_id = tl.program_id(1)
 
     req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
-    num_blocks_ptr = _load_ptr(num_blocks_ptrs + group_id, tl.int32)
-    num_blocks = tl.load(num_blocks_ptr + req_idx)
+    group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
+    num_blocks = tl.load(group_num_blocks_ptr + req_idx)
 
     stride = tl.load(block_table_strides + group_id)
     src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
@@ -272,14 +268,15 @@ def _compute_slot_mappings_kernel(
     block_table_ptrs,  # [num_kv_cache_groups]
     block_table_strides,  # [num_kv_cache_groups]
     page_sizes,  # [num_kv_cache_groups]
-    slot_mapping_ptrs,  # [num_kv_cache_groups]
+    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
+    slot_mappings_stride,
     PAD_ID: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     # kv cache group id
     group_id = tl.program_id(1)
-    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs + group_id, tl.int64)
+    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
 
     if req_idx == tl.num_programs(0) - 1:
         # Pad remaining slots to -1. This is needed for CUDA graphs.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cdd15bd2cd22f..63d5d9d554508 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -659,7 +659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
 
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc, self.positions[:total_num_scheduled_tokens])
+            query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
 
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index 2c276e0d93730..aad6cd5e5345a 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -109,8 +109,6 @@ class RequestState:
         self.req_id_to_index: dict[str, int] = {}
         self.index_to_req_id: dict[int, str] = {}
         self.free_indices = list(range(max_num_cached_reqs))
-        # Used to construct the input batch.
-        self._add_scalar_attr("idx_mapping", torch.int32)
 
         # Request states.
         self.req_data: dict[int, RequestData] = {}
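
The core change above replaces the per-group tensor lists (and the pointer-of-pointers indirection through `_load_ptr`) with single 2D tensors: each kernel now receives one base pointer plus a row stride and reaches its group's row as `base_ptr + group_id * stride`. Below is a minimal, self-contained sketch of that indexing pattern; the kernel and variable names here are hypothetical illustrations, not part of vLLM:

import torch
import triton
import triton.language as tl


@triton.jit
def _double_rows_kernel(
    x_ptr,  # [num_groups, row_len]
    x_stride,  # elements between consecutive group rows
    row_len,
    BLOCK_SIZE: tl.constexpr,
):
    group_id = tl.program_id(0)
    # Same pattern as the rewritten kernels: reach the group's row via the
    # row stride instead of dereferencing a per-group data pointer.
    row_ptr = x_ptr + group_id * x_stride
    for i in tl.range(0, row_len, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        mask = offset < row_len
        val = tl.load(row_ptr + offset, mask=mask)
        tl.store(row_ptr + offset, val * 2, mask=mask)


num_groups, row_len = 3, 1000
x = torch.ones(num_groups, row_len, dtype=torch.int32, device="cuda")
_double_rows_kernel[(num_groups, )](x, x.stride(0), row_len, BLOCK_SIZE=256)

A side benefit visible in the launch path: one 2D `CpuGpuBuffer` needs a single `copy_to_gpu()` per launch rather than one copy per KV cache group, which is why the call site no longer builds `cu_num_new_blocks_ptrs`/`new_block_ids_ptrs` with `_make_ptr_tensor`.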