diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py index 1690ad2cb730b..41bc5e09b96eb 100644 --- a/vllm/v1/worker/gpu_block_table.py +++ b/vllm/v1/worker/gpu_block_table.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Set import numpy as np import torch @@ -39,8 +39,12 @@ class GPUBlockTable: self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - # Append. - self.append_cnt = 0 + self.block_table_diff_np = np.zeros( + (max_num_reqs, 2), + dtype=np.int32, + ) + self.diff_rows: Set[int] = set() + self.append_row_indices = torch.zeros( (max_num_reqs, 2), dtype=torch.int32, @@ -75,20 +79,6 @@ class GPUBlockTable: ) self.append_data_np = self.append_data_cpu.numpy() - # Move. - self.move_cnt = 0 - self.move_src_dst = torch.zeros( - (max_num_reqs, 3), # (src, dst, num_blocks) - dtype=torch.int32, - device=self.device, - ) - self.move_src_dst_cpu = torch.zeros_like( - self.move_src_dst, - device="cpu", - pin_memory=pin_memory, - ) - self.move_src_dst_np = self.move_src_dst_cpu.numpy() - def append_row( self, row_idx: int, @@ -99,13 +89,9 @@ class GPUBlockTable: self.block_table_np[row_idx, start:start + num_blocks] = block_ids self.num_blocks_per_row[row_idx] = start + num_blocks - self.append_row_indices_np[self.append_cnt, 0] = row_idx - self.append_row_indices_np[self.append_cnt, 1] = start - append_start = self.append_cumsums_np[self.append_cnt] - append_end = append_start + num_blocks - self.append_cumsums_np[self.append_cnt + 1] = append_end - self.append_data_np[append_start:append_end] = block_ids - self.append_cnt += 1 + self.block_table_diff_np[row_idx, 0] = start + self.block_table_diff_np[row_idx, 1] = num_blocks + self.diff_rows.add(row_idx) def add_row(self, row_idx: int, block_ids: List[int]) -> None: self.append_row(row_idx, 0, block_ids) @@ -116,40 +102,49 @@ class GPUBlockTable: src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks - self.move_src_dst_np[self.move_cnt, 0] = src - self.move_src_dst_np[self.move_cnt, 1] = tgt - self.move_src_dst_np[self.move_cnt, 2] = num_blocks - self.move_cnt += 1 + self.block_table_diff_np[tgt, 0] = 0 + self.block_table_diff_np[tgt, 1] = num_blocks + self.diff_rows.discard(src) + self.diff_rows.add(tgt) def commit(self, num_reqs: int) -> None: - if self.append_cnt > 0: - total_num_append_blocks = self.append_cumsums_np[self.append_cnt] - ops.block_table_appends( - self.append_row_indices, - self.append_row_indices_cpu, - self.append_cumsums, - self.append_cumsums_cpu, - self.append_data, - self.append_data_cpu, - self.block_table, - self.append_cnt, - total_num_append_blocks, - ) - if self.move_cnt > 0: - ops.block_table_moves( - self.move_src_dst, - self.move_src_dst_cpu, - self.block_table, - self.move_cnt, - ) - self.append_cnt = 0 - self.move_cnt = 0 + if not self.diff_rows: + return + + cu_end = 0 + self.append_cumsums_np[0] = 0 + for i, row_idx in enumerate(self.diff_rows): + start, num_blocks = self.block_table_diff_np[row_idx] + assert num_blocks > 0 + + self.append_row_indices_np[i, 0] = row_idx + self.append_row_indices_np[i, 1] = start + cu_start = self.append_cumsums_np[i] + cu_end = cu_start + num_blocks + self.append_cumsums_np[i + 1] = cu_end + self.append_data_np[cu_start:cu_end] = self.block_table_np[ + row_idx, start:start + num_blocks] + + ops.block_table_appends( + self.append_row_indices, + self.append_row_indices_cpu, + self.append_cumsums, + self.append_cumsums_cpu, + self.append_data, + self.append_data_cpu, + self.block_table, + len(self.diff_rows), + cu_end, + ) + self.diff_rows.clear() def clear(self) -> None: self.block_table.fill_(0) self.block_table_cpu.fill_(0) - self.append_cnt = 0 + self.diff_rows.clear() + self.block_table_diff_np.fill(0) + self.append_row_indices.fill_(0) self.append_row_indices_cpu.fill_(0) self.append_cumsums.fill_(0) @@ -157,10 +152,6 @@ class GPUBlockTable: self.append_data.fill_(0) self.append_data_cpu.fill_(0) - self.move_cnt = 0 - self.move_src_dst.fill_(0) - self.move_src_dst_cpu.fill_(0) - def get_device_tensor(self) -> torch.Tensor: """Ruturns the device tensor of the block table.""" return self.block_table