diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index aa611ac0d336..24727713ff88 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -50,14 +50,15 @@ def main(args: argparse.Namespace):
         block_size=args.block_size,
     )
     sampling_params_dict = {
-        'n': 1,
-        'temperature': 0.0,
+        'n': args.n,
+        'temperature': 0.0 if args.use_beam_search else 1.0,
         'top_p': 1.0,
-        'use_beam_search': False,
+        'use_beam_search': args.use_beam_search,
         'stop_token_ids': set(),
         'max_num_steps': args.output_len,
     }
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    print(sampling_params)
     input_token_ids = [0] * args.input_len
 
     def profile_step(profile=False):
@@ -93,6 +94,8 @@ if __name__ == '__main__':
     parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py
index 3b53f34f4c73..1e358c7e5278 100644
--- a/cacheflow/models/sample.py
+++ b/cacheflow/models/sample.py
@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
         vocab_size = logprobs.size(-1)
         beam_width = len(seq_ids)
         _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        topk_ids = topk_ids.tolist()
+        seq_idx = [i // vocab_size for i in topk_ids]
         beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = (topk_ids % vocab_size).tolist()
+        token_ids = [i % vocab_size for i in topk_ids]
 
         beam_outputs: Dict[int, Tuple[int, int]] = {}
         outstanding_beams: List[Tuple[int, int]] = []
diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py
index 164b2a2a60fd..addde3883b69 100644
--- a/cacheflow/worker/cache_engine.py
+++ b/cacheflow/worker/cache_engine.py
@@ -120,24 +120,8 @@ class CacheEngine:
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
 
-    def _copy(
-        self,
-        src: List[KVCache],
-        dst: List[KVCache],
-        src_to_dsts: Dict[int, List[int]],
-    ) -> None:
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.copy_blocks(
-                    src_key_cache, dst_key_cache, src_to_dsts)
-                # Copy the value blocks.
-                cache_ops.copy_blocks(
-                    src_value_cache, dst_value_cache, src_to_dsts)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
-
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
+        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
+        value_caches = [value_cache for _, value_cache in self.gpu_cache]
+        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
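For readers of the `CacheEngine.copy` rewrite above: the deleted per-layer `_copy` loop issued two `cache_ops.copy_blocks` calls (key and value) per layer, while the new code gathers all per-layer tensors and makes one fused call. A minimal pure-PyTorch sketch of the semantics that single call must implement; `copy_blocks_reference` and the block-major cache layout are assumptions for illustration, not code from this PR:

```python
from typing import Dict, List

import torch


def copy_blocks_reference(
    key_caches: List[torch.Tensor],    # one tensor per layer; dim 0 indexes blocks
    value_caches: List[torch.Tensor],  # one tensor per layer; dim 0 indexes blocks
    src_to_dsts: Dict[int, List[int]],
) -> None:
    # Each (src, dst) pair is applied to the key and value caches of
    # every layer; the custom CUDA kernel does all of this in one launch.
    for src, dsts in src_to_dsts.items():
        for dst in dsts:
            for key_cache, value_cache in zip(key_caches, value_caches):
                key_cache[dst].copy_(key_cache[src])
                value_cache[dst].copy_(value_cache[src])
```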
diff --git a/csrc/cache.cpp b/csrc/cache.cpp
index fcf8b69fe860..907736a981c9 100644
--- a/csrc/cache.cpp
+++ b/csrc/cache.cpp
@@ -9,8 +9,8 @@ void swap_blocks(
   const std::map<int64_t, int64_t>& block_mapping);
 
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);
 
 void reshape_and_cache(
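The new `copy_blocks` signature accepts the mapping as a `std::map<int64_t, std::vector<int64_t>>`; the CUDA side (next diff) consumes it as a flat `[src0, dst0, src1, dst1, ...]` array indexed by `pair_idx`. A small sketch of that encoding; `flatten_block_mapping` is a hypothetical helper for illustration, not part of this PR:

```python
from typing import Dict, List


def flatten_block_mapping(block_mapping: Dict[int, List[int]]) -> List[int]:
    # Pairs are laid out as [src0, dst0, src1, dst1, ...], so CUDA block
    # (layer_idx, pair_idx) reads entries 2 * pair_idx and 2 * pair_idx + 1.
    flat: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            flat.extend([src, dst])
    return flat


# Example: two mappings, one of which fans out to two destinations.
assert flatten_block_mapping({0: [3], 1: [4, 5]}) == [0, 3, 1, 4, 1, 5]
```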
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 8b5537c47229..3a34ba578980 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -43,33 +43,93 @@ void swap_blocks(
   }
 }
 
+namespace cacheflow {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int src_block_number = block_mapping[2 * pair_idx];
+  int dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int src_block_offset = src_block_number * numel_per_block;
+  const int dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace cacheflow
+
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
-  torch::Device src_device = src.device();
-  torch::Device dst_device = dst.device();
-  assert(src_device.is_cuda() && dst_device.is_cuda());
-  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());
 
-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
-
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      int64_t src_offset = src_block_number * block_size_in_bytes;
-      int64_t dst_offset = dst_block_number * block_size_in_bytes;
-      cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
+    int src_block_number = pair.first;
+    for (int dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
     }
   }
+  int* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int>(),
+        numel_per_block);
+    }));
 }
 
 namespace cacheflow {
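Two details of the launch above are worth noting: the kernel receives one raw device pointer per layer (packed into `int64_t` tensors) so a single grid of `(num_layers, num_pairs)` blocks covers every layer, and the `torch::from_blob(...).to(cache_device)` copies block the host, which is what the synchronization NOTEs in both this file and `cache_engine.py` refer to. A hedged Python rendering of the pointer-packing step; `pack_cache_ptrs` is illustrative only, not a function in this repo:

```python
from typing import List

import torch


def pack_cache_ptrs(caches: List[torch.Tensor]) -> torch.Tensor:
    # Collect each layer's raw CUDA address into an int64 tensor that the
    # kernel indexes by blockIdx.x. Building the tensor on the host and
    # moving it with .to(device) is a blocking copy, mirroring the
    # synchronization NOTE in the C++ code.
    ptrs = torch.tensor([t.data_ptr() for t in caches], dtype=torch.int64)
    return ptrs.to(caches[0].device)
```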
diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py
index d6b1c3d2dd48..89f14cca82a2 100644
--- a/tests/kernels/cache.py
+++ b/tests/kernels/cache.py
@@ -5,6 +5,61 @@ import torch
 from cacheflow import cache_ops
 
 
+def test_copy_blocks(
+    num_mappings: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    # Generate random block mappings.
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining_blocks, num_mappings)
+    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.randn(
+            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_caches.append(key_cache)
+    cloned_key_caches = []
+    for key_cache in key_caches:
+        cloned_key_caches.append(key_cache.clone())
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.randn(
+            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_caches.append(value_cache)
+    cloned_value_caches = []
+    for value_cache in value_caches:
+        cloned_value_caches.append(value_cache.clone())
+
+    # Call the copy blocks kernel.
+    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+    # Reference implementation.
+    for src, dsts in block_mapping.items():
+        for dst in dsts:
+            for cloned_key_cache in cloned_key_caches:
+                cloned_key_cache[dst] = cloned_key_cache[src]
+            for cloned_value_cache in cloned_value_caches:
+                cloned_value_cache[dst] = cloned_value_cache[src]
+
+    # Compare the results.
+    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        assert torch.allclose(key_cache, cloned_key_cache)
+    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+        assert torch.allclose(value_cache, cloned_value_cache)
+
+
 def test_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
@@ -46,6 +101,9 @@ def test_reshape_and_cache(
 
 @torch.inference_mode()
 def test_cache() -> None:
+    test_copy_blocks(
+        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+        block_size=8, num_blocks=1024, dtype=torch.half)
     test_reshape_and_cache(
         num_tokens=3, num_heads=2, head_size=16, block_size=8,
         num_blocks=2, dtype=torch.half)
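If useful for quick local verification, the new test can also be driven standalone. A hypothetical runner, assuming the `random` and `torch` imports already at the top of `tests/kernels/cache.py` and a CUDA-capable device:

```python
if __name__ == '__main__':
    random.seed(0)
    torch.manual_seed(0)
    # Same configuration test_cache() uses, repeated for an extra dtype.
    for dtype in (torch.half, torch.float):
        test_copy_blocks(
            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
            block_size=8, num_blocks=1024, dtype=dtype)
    print('copy_blocks: OK')
```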