From 223e17424cc26581b9717dbe51f825568090b25a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 15 Jan 2025 01:24:45 -0800
Subject: [PATCH] working

Signed-off-by: Woosuk Kwon
---
 CMakeLists.txt                    |   1 +
 csrc/block_table.cu               |  92 ++++++++++++++++
 csrc/ops.h                        |  12 +++
 csrc/torch_bindings.cpp           |  13 +++
 vllm/_custom_ops.py               |  26 +++++
 vllm/v1/worker/block_table.py     |   1 +
 vllm/v1/worker/gpu_block_table.py | 174 ++++++++++++++++++++++++++++++
 vllm/v1/worker/gpu_input_batch.py |   4 +-
 8 files changed, 321 insertions(+), 2 deletions(-)
 create mode 100644 csrc/block_table.cu
 create mode 100644 vllm/v1/worker/gpu_block_table.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f4b9c3ec9c14f..74c416ff1d8b8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -187,6 +187,7 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
 set(VLLM_EXT_SRC
   "csrc/cache_kernels.cu"
+  "csrc/block_table.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
   "csrc/pos_encoding_kernels.cu"
diff --git a/csrc/block_table.cu b/csrc/block_table.cu
new file mode 100644
index 0000000000000..29eca5cb806ba
--- /dev/null
+++ b/csrc/block_table.cu
@@ -0,0 +1,92 @@
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+namespace vllm {
+__global__ void append_kernel(const int* __restrict__ row_indices,
+                              const int* __restrict__ cu_num_appends,
+                              const int* __restrict__ block_ids,
+                              int* __restrict__ block_table,
+                              int max_num_blocks_per_row) {
+  int bid = blockIdx.x;
+  int tgt_row = row_indices[2 * bid];
+  int tgt_offset = row_indices[2 * bid + 1];
+
+  int start = cu_num_appends[bid];
+  int end = cu_num_appends[bid + 1];
+  int length = end - start;
+  int tid = threadIdx.x;
+  int64_t offset = tgt_row * max_num_blocks_per_row + tgt_offset;
+  for (int i = tid; i < length; i += blockDim.x) {
+    block_table[offset + i] = block_ids[start + i];
+  }
+}
+
+__global__ void move_kernel(const int* __restrict__ src_dst_n,
+                            int* __restrict__ block_table,
+                            int max_num_blocks_per_row) {
+  int bid = blockIdx.x;
+  int src_row = src_dst_n[3 * bid];
+  int tgt_row = src_dst_n[3 * bid + 1];
+  int num_blocks = src_dst_n[3 * bid + 2];
+
+  int tid = threadIdx.x;
+  for (int i = tid; i < num_blocks; i += blockDim.x) {
+    block_table[tgt_row * max_num_blocks_per_row + i] =
+        block_table[src_row * max_num_blocks_per_row + i];
+  }
+}
+}  // namespace vllm
+
+void block_table_appends(
+    torch::Tensor& append_row_indices,
+    torch::Tensor& append_row_indices_cpu,
+    torch::Tensor& append_cumsums,
+    torch::Tensor& append_cumsums_cpu,
+    torch::Tensor& append_block_ids,
+    torch::Tensor& append_block_ids_cpu,
+    torch::Tensor& block_table,
+    int64_t num_appends,
+    int64_t total_num_append_blocks) {
+  int* append_row_indices_ptr = append_row_indices.data_ptr<int>();
+  const int* append_row_indices_cpu_ptr = append_row_indices_cpu.data_ptr<int>();
+  int* append_cumsums_ptr = append_cumsums.data_ptr<int>();
+  const int* append_cumsums_cpu_ptr = append_cumsums_cpu.data_ptr<int>();
+  int* append_block_ids_ptr = append_block_ids.data_ptr<int>();
+  const int* append_block_ids_cpu_ptr = append_block_ids_cpu.data_ptr<int>();
+  int* block_table_ptr = block_table.data_ptr<int>();
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  cudaMemcpyAsync(append_row_indices_ptr, append_row_indices_cpu_ptr,
+                  num_appends * 2 * sizeof(int), cudaMemcpyHostToDevice,
+                  stream);
+  cudaMemcpyAsync(append_cumsums_ptr, append_cumsums_cpu_ptr,
+                  (num_appends + 1) * sizeof(int), cudaMemcpyHostToDevice,
+                  stream);
+  cudaMemcpyAsync(append_block_ids_ptr, append_block_ids_cpu_ptr,
+                  total_num_append_blocks * sizeof(int),
+                  cudaMemcpyHostToDevice, stream);
+
+  int64_t max_num_blocks_per_row = block_table.size(1);
+  vllm::append_kernel<<<num_appends, 1024, 0, stream>>>(
+      append_row_indices_ptr, append_cumsums_ptr, append_block_ids_ptr,
+      block_table_ptr, max_num_blocks_per_row);
+}
+
+void block_table_moves(
+    torch::Tensor& src_dst_n,
+    torch::Tensor& src_dst_n_cpu,
+    torch::Tensor& block_table,
+    int64_t num_moves) {
+  int* src_dst_n_ptr = src_dst_n.data_ptr<int>();
+  const int* src_dst_n_cpu_ptr = src_dst_n_cpu.data_ptr<int>();
+  int* block_table_ptr = block_table.data_ptr<int>();
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  cudaMemcpyAsync(src_dst_n_ptr, src_dst_n_cpu_ptr,
+                  num_moves * 3 * sizeof(int), cudaMemcpyHostToDevice, stream);
+
+  int64_t max_num_blocks_per_row = block_table.size(1);
+  vllm::move_kernel<<<num_moves, 1024, 0, stream>>>(
+      src_dst_n_ptr, block_table_ptr, max_num_blocks_per_row);
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 5a194a0dd3654..459aa8b64773a 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -117,6 +117,18 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+void block_table_appends(torch::Tensor& append_row_indices,
+                         torch::Tensor& append_row_indices_cpu,
+                         torch::Tensor& append_cumsums,
+                         torch::Tensor& append_cumsums_cpu,
+                         torch::Tensor& append_block_ids,
+                         torch::Tensor& append_block_ids_cpu,
+                         torch::Tensor& block_table, int64_t num_appends,
+                         int64_t total_num_append_blocks);
+
+void block_table_moves(torch::Tensor& src_dst_n, torch::Tensor& src_dst_n_cpu,
+                       torch::Tensor& block_table, int64_t num_moves);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index fb53d122487d3..6e07531f74f3a 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -101,6 +101,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
+  ops.def(
+      "block_table_appends(Tensor append_row_indices, "
+      "Tensor append_row_indices_cpu, Tensor append_cumsums, "
+      "Tensor append_cumsums_cpu, Tensor append_block_ids, "
+      "Tensor append_block_ids_cpu, Tensor! block_table, int num_appends, "
+      "int total_num_append_blocks) -> ()");
+  ops.impl("block_table_appends", torch::kCUDA, &block_table_appends);
+
+  ops.def(
+      "block_table_moves(Tensor src_dst_n, Tensor src_dst_n_cpu, "
+      "Tensor! block_table, int num_moves) -> ()");
+  ops.impl("block_table_moves", torch::kCUDA, &block_table_moves);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index d04cbbc0a9eed..5f760e2e0efe0 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -187,6 +187,32 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                                 paged_kv_indices, paged_kv_indptr,
                                 paged_kv_last_page_len, block_table_bound)
 
 
+def block_table_appends(
+    append_row_indices: torch.Tensor,
+    append_row_indices_cpu: torch.Tensor,
+    append_cumsums: torch.Tensor,
+    append_cumsums_cpu: torch.Tensor,
+    append_block_ids: torch.Tensor,
+    append_block_ids_cpu: torch.Tensor,
+    block_table: torch.Tensor,
+    num_appends: int,
+    total_num_append_blocks: int,
+) -> None:
+    torch.ops._C.block_table_appends.default(
+        append_row_indices, append_row_indices_cpu, append_cumsums,
+        append_cumsums_cpu, append_block_ids, append_block_ids_cpu,
+        block_table, num_appends, total_num_append_blocks)
+
+
+def block_table_moves(
+    src_dst_n: torch.Tensor,
+    src_dst_n_cpu: torch.Tensor,
+    block_table: torch.Tensor,
+    num_moves: int,
+) -> None:
+    torch.ops._C.block_table_moves.default(src_dst_n, src_dst_n_cpu,
+                                           block_table, num_moves)
+
 
 # fused quant layer norm ops
 def rms_norm_dynamic_per_token_quant(
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 26a2084b131fa..9ec1df7758797 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -9,6 +9,7 @@
 logger = init_logger(__name__)
 
 class BlockTable:
+    """Device-agnostic block table for storing block IDs for each request."""
 
     def __init__(
         self,
diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
new file mode 100644
index 0000000000000..1690ad2cb730b
--- /dev/null
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -0,0 +1,174 @@
+from typing import List
+
+import numpy as np
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class GPUBlockTable:
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_blocks_per_req: int,
+        pin_memory: bool,
+        device: torch.device,
+    ):
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.block_table = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device=self.device,
+            dtype=torch.int32,
+        )
+        self.block_table_cpu = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=False,
+        )
+        self.block_table_np = self.block_table_cpu.numpy()
+        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
+
+        # Append.
+        self.append_cnt = 0
+        self.append_row_indices = torch.zeros(
+            (max_num_reqs, 2),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.append_row_indices_cpu = torch.zeros_like(
+            self.append_row_indices,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        self.append_row_indices_np = self.append_row_indices_cpu.numpy()
+        self.append_cumsums = torch.zeros(
+            (max_num_reqs + 1,),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.append_cumsums_cpu = torch.zeros_like(
+            self.append_cumsums,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        self.append_cumsums_np = self.append_cumsums_cpu.numpy()
+        self.append_data = torch.zeros(
+            (max_num_reqs * max_num_blocks_per_req,),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.append_data_cpu = torch.zeros_like(
+            self.append_data,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        self.append_data_np = self.append_data_cpu.numpy()
+
+        # Move.
+        self.move_cnt = 0
+        self.move_src_dst = torch.zeros(
+            (max_num_reqs, 3),  # (src, dst, num_blocks)
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.move_src_dst_cpu = torch.zeros_like(
+            self.move_src_dst,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        self.move_src_dst_np = self.move_src_dst_cpu.numpy()
+
+    def append_row(
+        self,
+        row_idx: int,
+        start: int,
+        block_ids: List[int],
+    ) -> None:
+        num_blocks = len(block_ids)
+        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+        self.num_blocks_per_row[row_idx] = start + num_blocks
+
+        self.append_row_indices_np[self.append_cnt, 0] = row_idx
+        self.append_row_indices_np[self.append_cnt, 1] = start
+        append_start = self.append_cumsums_np[self.append_cnt]
+        append_end = append_start + num_blocks
+        self.append_cumsums_np[self.append_cnt + 1] = append_end
+        self.append_data_np[append_start:append_end] = block_ids
+        self.append_cnt += 1
+
+    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
+        self.append_row(row_idx, 0, block_ids)
+
+    def move_row(self, src: int, tgt: int) -> None:
+        num_blocks = self.num_blocks_per_row[src]
+        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
+            src, :num_blocks]
+        self.num_blocks_per_row[tgt] = num_blocks
+
+        self.move_src_dst_np[self.move_cnt, 0] = src
+        self.move_src_dst_np[self.move_cnt, 1] = tgt
+        self.move_src_dst_np[self.move_cnt, 2] = num_blocks
+        self.move_cnt += 1
+
+    def commit(self, num_reqs: int) -> None:
+        if self.append_cnt > 0:
+            total_num_append_blocks = self.append_cumsums_np[self.append_cnt]
+            ops.block_table_appends(
+                self.append_row_indices,
+                self.append_row_indices_cpu,
+                self.append_cumsums,
+                self.append_cumsums_cpu,
+                self.append_data,
+                self.append_data_cpu,
+                self.block_table,
+                self.append_cnt,
+                total_num_append_blocks,
+            )
+        if self.move_cnt > 0:
+            ops.block_table_moves(
+                self.move_src_dst,
+                self.move_src_dst_cpu,
+                self.block_table,
+                self.move_cnt,
+            )
+        self.append_cnt = 0
+        self.move_cnt = 0
+
+    def clear(self) -> None:
+        self.block_table.fill_(0)
+        self.block_table_cpu.fill_(0)
+
+        self.append_cnt = 0
+        self.append_row_indices.fill_(0)
+        self.append_row_indices_cpu.fill_(0)
+        self.append_cumsums.fill_(0)
+        self.append_cumsums_cpu.fill_(0)
+        self.append_data.fill_(0)
+        self.append_data_cpu.fill_(0)
+
+        self.move_cnt = 0
+        self.move_src_dst.fill_(0)
+        self.move_src_dst_cpu.fill_(0)
+
+    def get_device_tensor(self) -> torch.Tensor:
+        """Returns the device tensor of the block table."""
+        return self.block_table
+
+    def get_cpu_tensor(self) -> torch.Tensor:
+        """Returns the CPU tensor of the block table."""
+        return self.block_table_cpu
+
+    def get_numpy_array(self) -> np.ndarray:
+        """Returns the numpy array of the block table."""
+        return self.block_table_np
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 40494e64b22f0..8ab10fb41fb15 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -9,7 +9,7 @@ import torch
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.block_table import BlockTable
+from vllm.v1.worker.gpu_block_table import GPUBlockTable
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange
@@ -72,7 +72,7 @@ class InputBatch:
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
 
         # Block table.
-        self.block_table = BlockTable(
+        self.block_table = GPUBlockTable(
             max_num_reqs=max_num_reqs,
             max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
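
Reviewer note (not part of the patch): below is a minimal usage sketch of the GPUBlockTable API added above, assuming the _C extension has been rebuilt with the new block_table_appends/block_table_moves ops and a CUDA device is available. The sizes are made up for illustration; the call pattern (stage updates on the CPU side, then one commit() per step) is the one InputBatch uses.

    import torch
    from vllm.v1.worker.gpu_block_table import GPUBlockTable

    block_table = GPUBlockTable(
        max_num_reqs=4,
        max_model_len=2048,
        max_num_blocks_per_req=128,
        pin_memory=True,
        device=torch.device("cuda:0"),
    )

    # Stage updates in the pinned CPU buffers; no GPU work happens yet.
    block_table.add_row(0, [10, 11, 12])  # new request in row 0
    block_table.append_row(0, 3, [13])    # the request grew by one block
    block_table.move_row(0, 1)            # e.g. batch compaction

    # One commit per step: a few async H2D copies plus two small kernels,
    # instead of one synchronous copy per updated row.
    block_table.commit(num_reqs=2)
    torch.cuda.synchronize()
    assert block_table.get_device_tensor()[1, :4].tolist() == [10, 11, 12, 13]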