mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 02:05:01 +08:00
[Bugfix][cache_kernels]: Fix OOB in cache_kernels.cu (#28760)
Signed-off-by: vensen <vensenmu@gmail.com> Signed-off-by: Vensenmu <vensenmu@gmail.com>
This commit is contained in:
parent
a903d59ffa
commit
fb8851f254
@ -965,7 +965,9 @@ __global__ void gather_and_maybe_dequant_cache(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for (int pid = split_start; pid < full_blocks_end; ++pid) {
|
const auto loop_end =
|
||||||
|
std::min((int64_t)full_blocks_end, block_table_stride - offset);
|
||||||
|
for (int pid = split_start; pid < loop_end; ++pid) {
|
||||||
auto block_id = batch_block_table[pid];
|
auto block_id = batch_block_table[pid];
|
||||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||||
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
|
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
|
||||||
@ -976,12 +978,15 @@ __global__ void gather_and_maybe_dequant_cache(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (partial_block_size) {
|
if (partial_block_size) {
|
||||||
auto block_id = batch_block_table[full_blocks_end];
|
if (offset + full_blocks_end < block_table_stride) {
|
||||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
auto block_id = batch_block_table[full_blocks_end];
|
||||||
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
|
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||||
for (int eid = 0; eid < partial_block_size; ++eid) {
|
auto block_dst_ptr =
|
||||||
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
dst + full_blocks_end * block_size * dst_entry_stride;
|
||||||
block_dst_ptr + eid * dst_entry_stride);
|
for (int eid = 0; eid < partial_block_size; ++eid) {
|
||||||
|
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
||||||
|
block_dst_ptr + eid * dst_entry_stride);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
65
tests/kernels/test_cache_kernels.py
Normal file
65
tests/kernels/test_cache_kernels.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Unit tests for CUDA kernels in cache_kernels.cu."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip(
|
||||||
|
"Could not import vllm._custom_ops. (pip install -e .)", allow_module_level=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
|
||||||
|
def test_gather_cache_oob():
|
||||||
|
"""
|
||||||
|
Tests for OOB read in gather_and_maybe_dequant_cache (Issue #27909).
|
||||||
|
This test constructs a boundary case identified in the issue where
|
||||||
|
seq_starts causes the block_table offset to read out of bounds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
block_size = 64
|
||||||
|
entry_size = 128
|
||||||
|
|
||||||
|
block_table = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# This will result in offset = 128 / block_size = 128 / 64 = 2
|
||||||
|
# This will cause the kernel to try to read from
|
||||||
|
# block_table[0, 2], but its size is only 2.
|
||||||
|
seq_starts = torch.tensor([128], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
seq_len = 65
|
||||||
|
cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# src_cache: [num_blocks, block_size, entry_size]
|
||||||
|
num_blocks = 5
|
||||||
|
src_cache = torch.randn(
|
||||||
|
(num_blocks, block_size, entry_size), dtype=torch.float16, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
dst = torch.empty((seq_len, entry_size), dtype=torch.float16, device="cuda")
|
||||||
|
|
||||||
|
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
# Calling the C++ function gather_and_maybe_dequant_cache
|
||||||
|
ops.gather_and_maybe_dequant_cache(
|
||||||
|
src_cache,
|
||||||
|
dst,
|
||||||
|
block_table,
|
||||||
|
cu_seq_lens,
|
||||||
|
batch_size,
|
||||||
|
"auto", # kv_cache_dtype
|
||||||
|
scale,
|
||||||
|
seq_starts,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Loading…
x
Reference in New Issue
Block a user