mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 23:51:03 +08:00
[Misc] Remove unused custom ops copy_blocks and copy_blocks_mla (#30967)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
parent
dd424571c8
commit
4ed11105d7
10
csrc/cache.h
10
csrc/cache.h
@ -9,16 +9,6 @@
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
|
||||
@ -119,94 +119,6 @@ __global__ void copy_blocks_mla_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& block_mapping) {
|
||||
int num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
torch::Device cache_device = key_caches[0].device();
|
||||
TORCH_CHECK(cache_device.is_cuda());
|
||||
|
||||
// Create data structures for the kernel.
|
||||
// Create an array of pointers to the key and value caches.
|
||||
int64_t key_cache_ptrs[num_layers];
|
||||
int64_t value_cache_ptrs[num_layers];
|
||||
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
||||
key_cache_ptrs[layer_idx] =
|
||||
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
||||
value_cache_ptrs[layer_idx] =
|
||||
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||
}
|
||||
|
||||
// block_mapping is a 2D tensor with shape (num_pairs, 2).
|
||||
int num_pairs = block_mapping.size(0);
|
||||
|
||||
// Move the data structures to the GPU.
|
||||
// NOTE: This synchronizes the CPU and GPU.
|
||||
torch::Tensor key_cache_ptrs_tensor =
|
||||
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
|
||||
.to(cache_device);
|
||||
torch::Tensor value_cache_ptrs_tensor =
|
||||
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
|
||||
.to(cache_device);
|
||||
|
||||
// Launch the kernel.
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
block_mapping.data_ptr<int64_t>(), numel_per_block);
|
||||
}));
|
||||
}
|
||||
|
||||
// copy blocks kernel for MLA (assumes a joint KV-cache)
|
||||
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
|
||||
const torch::Tensor& block_mapping) {
|
||||
int num_layers = kv_caches.size();
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
torch::Device cache_device = kv_caches[0].device();
|
||||
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
|
||||
|
||||
std::vector<int64_t> cache_ptrs(num_layers);
|
||||
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
||||
cache_ptrs[layer_idx] =
|
||||
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
|
||||
}
|
||||
torch::Tensor cache_ptrs_tensor =
|
||||
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
|
||||
.to(cache_device);
|
||||
|
||||
int num_pairs = block_mapping.size(0);
|
||||
// We use the stride instead of numel in case the cache is padded for memory
|
||||
// alignment reasons, we assume the blocks data (inclusive of any padding)
|
||||
// is contiguous in memory
|
||||
int mem_footprint_per_block = kv_caches[0].stride(0);
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, mem_footprint_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
|
||||
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
|
||||
}));
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Used to copy/convert one element
|
||||
|
||||
@ -685,16 +685,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks);
|
||||
|
||||
cache_ops.def(
|
||||
"copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks_mla", torch::kCUDA, ©_blocks_mla);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
|
||||
@ -40,93 +40,6 @@ KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
num_mappings: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
pytest.skip()
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
|
||||
block_mapping: list[tuple[int, int]] = []
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping.append((src, dst1))
|
||||
block_mapping.append((src, dst2))
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_layers,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.copy_blocks,
|
||||
(key_caches, value_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(head_size == HEAD_SIZES[0]),
|
||||
)
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
|
||||
|
||||
# Run the reference implementation.
|
||||
for src, dst in block_mapping:
|
||||
for cloned_key_cache in cloned_key_caches:
|
||||
cloned_key_cache[dst].copy_(cloned_key_cache[src])
|
||||
for cloned_value_cache in cloned_value_caches:
|
||||
cloned_value_cache[dst].copy_(cloned_value_cache[src])
|
||||
|
||||
# Compare the results.
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
torch.testing.assert_close(key_cache, cloned_key_cache)
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
torch.testing.assert_close(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@ -763,73 +676,6 @@ def test_concat_and_cache_ds_mla(
|
||||
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
kv_caches = []
|
||||
for _ in range(num_layers):
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
kv_caches.append(kv_cache)
|
||||
|
||||
ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
|
||||
|
||||
num_mappings = min(2, num_blocks // 2)
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remaining = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remaining, 2 * num_mappings)
|
||||
block_mapping = []
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping.append((src, dst1))
|
||||
block_mapping.append((src, dst2))
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
for ref_cache in ref_caches:
|
||||
ref_cache[dst].copy_(ref_cache[src])
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.copy_blocks_mla,
|
||||
(kv_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
|
||||
|
||||
for kv_cache, ref_cache in zip(kv_caches, ref_caches):
|
||||
torch.testing.assert_close(kv_cache, ref_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
||||
@ -2328,18 +2328,6 @@ def concat_and_cache_mla(
|
||||
)
|
||||
|
||||
|
||||
def copy_blocks(
|
||||
key_caches: list[torch.Tensor],
|
||||
value_caches: list[torch.Tensor],
|
||||
block_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
|
||||
def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None:
|
||||
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
|
||||
|
||||
|
||||
def swap_blocks(
|
||||
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
|
||||
) -> None:
|
||||
|
||||
@ -383,18 +383,6 @@ class ipex_ops:
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
key_caches: list[torch.Tensor],
|
||||
value_caches: list[torch.Tensor],
|
||||
block_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.xpu.copy_blocks( # type: ignore
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user