diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0aa0dc14c7480..a6c953ee0eac9 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -965,7 +965,9 @@ __global__ void gather_and_maybe_dequant_cache(
     }
   };
 
-  for (int pid = split_start; pid < full_blocks_end; ++pid) {
+  const auto loop_end =
+      std::min((int64_t)full_blocks_end, block_table_stride - offset);
+  for (int pid = split_start; pid < loop_end; ++pid) {
     auto block_id = batch_block_table[pid];
     auto block_start_ptr = src_cache + block_id * cache_block_stride;
     auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
@@ -976,12 +978,15 @@ __global__ void gather_and_maybe_dequant_cache(
   }
 
   if (partial_block_size) {
-    auto block_id = batch_block_table[full_blocks_end];
-    auto block_start_ptr = src_cache + block_id * cache_block_stride;
-    auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
-    for (int eid = 0; eid < partial_block_size; ++eid) {
-      copy_entry(block_start_ptr + eid * cache_entry_stride,
-                 block_dst_ptr + eid * dst_entry_stride);
+    if (offset + full_blocks_end < block_table_stride) {
+      auto block_id = batch_block_table[full_blocks_end];
+      auto block_start_ptr = src_cache + block_id * cache_block_stride;
+      auto block_dst_ptr =
+          dst + full_blocks_end * block_size * dst_entry_stride;
+      for (int eid = 0; eid < partial_block_size; ++eid) {
+        copy_entry(block_start_ptr + eid * cache_entry_stride,
+                   block_dst_ptr + eid * dst_entry_stride);
+      }
     }
   }
 }
diff --git a/tests/kernels/test_cache_kernels.py b/tests/kernels/test_cache_kernels.py
new file mode 100644
index 0000000000000..b5d66b4ede886
--- /dev/null
+++ b/tests/kernels/test_cache_kernels.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for CUDA kernels in cache_kernels.cu."""
+
+import pytest
+import torch
+
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    pytest.skip(
+        "Could not import vllm._custom_ops. (pip install -e .)", allow_module_level=True
+    )
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
+def test_gather_cache_oob():
+    """
+    Tests for an OOB read in gather_and_maybe_dequant_cache (Issue #27909).
+    This test constructs the boundary case identified in the issue, where
+    seq_starts pushes the block_table offset past the end of the table.
+    """
+
+    batch_size = 1
+    block_size = 64
+    entry_size = 128
+
+    block_table = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")
+
+    # seq_starts = 128 gives offset = 128 / block_size = 128 / 64 = 2, so
+    # the kernel would try to read block_table[0, 2], one column past the
+    # end of the two-column table.
+    seq_starts = torch.tensor([128], dtype=torch.int32, device="cuda")
+
+    seq_len = 65
+    cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")
+
+    # src_cache: [num_blocks, block_size, entry_size]
+    num_blocks = 5
+    src_cache = torch.randn(
+        (num_blocks, block_size, entry_size), dtype=torch.float16, device="cuda"
+    )
+
+    dst = torch.empty((seq_len, entry_size), dtype=torch.float16, device="cuda")
+
+    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+
+    # Before the fix, this call read one block_table entry out of bounds.
+    ops.gather_and_maybe_dequant_cache(
+        src_cache,
+        dst,
+        block_table,
+        cu_seq_lens,
+        batch_size,
+        "auto",  # kv_cache_dtype
+        scale,
+        seq_starts,
+    )
+
+    # Passes if the kernel completes without an illegal memory access.
+    torch.cuda.synchronize()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
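
For reference, below is a minimal host-side sketch of the indexing this patch guards. Variable names mirror those in gather_and_maybe_dequant_cache, but the derivation of offset, full_blocks_end, and partial_block_size from seq_starts/seq_len is an assumption reconstructed from the hunk context (split_start is taken as 0, i.e. a single split), not code shown in the diff:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Mirrors the unit test: a two-column block table, seq_starts = 128,
  // seq_len = 65, block_size = 64.
  const int block_size = 64;
  const int64_t block_table_stride = 2;  // columns per batch row
  std::vector<int32_t> block_table = {1, 2};

  const int64_t offset = 128 / block_size;  // seq_starts[0] / block_size = 2
  // Forming this one-past-the-end pointer is legal; dereferencing it is not.
  const int32_t* batch_block_table = block_table.data() + offset;

  const int seq_len = 65;
  const int full_blocks_end = seq_len / block_size;     // = 1
  const int partial_block_size = seq_len % block_size;  // = 1

  // Unclamped, the loop would dereference batch_block_table[0], i.e.
  // block_table[2]: one element past the end of the table. The clamp
  // from the patch reduces loop_end to 0 here, so the loop body never runs.
  const int64_t loop_end =
      std::min((int64_t)full_blocks_end, block_table_stride - offset);
  for (int pid = 0; pid < loop_end; ++pid) {
    std::printf("full block %d -> block_id %d\n", pid,
                (int)batch_block_table[pid]);
  }

  // Same guard for the partial tail: offset + full_blocks_end = 3 >= 2,
  // so the tail copy is skipped instead of reading block_table[2].
  if (partial_block_size > 0 &&
      offset + full_blocks_end < block_table_stride) {
    std::printf("partial block -> block_id %d\n",
                (int)batch_block_table[full_blocks_end]);
  } else {
    std::printf("partial block skipped (would read out of bounds)\n");
  }
  return 0;
}

With the test's parameters (offset = 2 into a two-column table), loop_end clamps to 0 and the partial-block guard fails, so nothing past block_table[1] is ever dereferenced. On an unpatched build, the OOB read can be observed deterministically by running the new test under compute-sanitizer, e.g. compute-sanitizer --tool memcheck pytest tests/kernels/test_cache_kernels.py.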