diff --git a/csrc/cache.h b/csrc/cache.h
index cbe44c09eb624..42ccb589683a9 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -9,16 +9,6 @@
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping);
 
-// Note: the key_caches and value_caches vectors are constant but
-// not the Tensors they contain. The vectors need to be const refs
-// in order to satisfy pytorch's C++ operator registration code.
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping);
-
-void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
-                     const torch::Tensor& block_mapping);
-
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index f11c5f24c12ec..cf26ae544deaa 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -119,94 +119,6 @@ __global__ void copy_blocks_mla_kernel(
 
 }  // namespace vllm
 
-// Note: the key_caches and value_caches vectors are constant but
-// not the Tensors they contain. The vectors need to be const refs
-// in order to satisfy pytorch's C++ operator registration code.
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping) {
-  int num_layers = key_caches.size();
-  TORCH_CHECK(num_layers == value_caches.size());
-  if (num_layers == 0) {
-    return;
-  }
-  torch::Device cache_device = key_caches[0].device();
-  TORCH_CHECK(cache_device.is_cuda());
-
-  // Create data structures for the kernel.
-  // Create an array of pointers to the key and value caches.
-  int64_t key_cache_ptrs[num_layers];
-  int64_t value_cache_ptrs[num_layers];
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
-  }
-
-  // block_mapping is a 2D tensor with shape (num_pairs, 2).
-  int num_pairs = block_mapping.size(0);
-
-  // Move the data structures to the GPU.
-  // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor =
-      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
-          .to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor =
-      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
-          .to(cache_device);
-
-  // Launch the kernel.
-  const int numel_per_block = key_caches[0][0].numel();
-  dim3 grid(num_layers, num_pairs);
-  dim3 block(std::min(1024, numel_per_block));
-  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            key_cache_ptrs_tensor.data_ptr<int64_t>(),
-            value_cache_ptrs_tensor.data_ptr<int64_t>(),
-            block_mapping.data_ptr<int64_t>(), numel_per_block);
-      }));
-}
-
-// copy blocks kernel for MLA (assumes a joint KV-cache)
-void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
-                     const torch::Tensor& block_mapping) {
-  int num_layers = kv_caches.size();
-  if (num_layers == 0) {
-    return;
-  }
-  torch::Device cache_device = kv_caches[0].device();
-  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
-
-  std::vector<int64_t> cache_ptrs(num_layers);
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
-  }
-  torch::Tensor cache_ptrs_tensor =
-      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
-          .to(cache_device);
-
-  int num_pairs = block_mapping.size(0);
-  // We use the stride instead of numel in case the cache is padded for memory
-  // alignment reasons, we assume the blocks data (inclusive of any padding)
-  // is contiguous in memory
-  int mem_footprint_per_block = kv_caches[0].stride(0);
-  dim3 grid(num_layers, num_pairs);
-  dim3 block(std::min(1024, mem_footprint_per_block));
-  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
-        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            cache_ptrs_tensor.data_ptr<int64_t>(),
-            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
-      }));
-}
-
 namespace vllm {
 
 // Used to copy/convert one element
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 461f74ca184fd..6f2c8e915b5cb 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -685,16 +685,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
   cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 
-  // Copy the cache blocks from src to dst.
-  cache_ops.def(
-      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
-      "Tensor block_mapping) -> ()");
-  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
-
-  cache_ops.def(
-      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
-  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
-
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index acf46d75d62eb..3f76033254d32 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -40,93 +40,6 @@
 KV_CACHE_DTYPE = ["auto", "fp8"]
 RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
 
-@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
-@pytest.mark.parametrize("num_layers", NUM_LAYERS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@torch.inference_mode()
-def test_copy_blocks(
-    kv_cache_factory,
-    num_mappings: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-    seed: int,
-    kv_cache_dtype: str,
-    device: str,
-) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
-        pytest.skip()
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    torch.cuda.set_device(device)
-    # Generate random block mappings where each source block is mapped to two
-    # destination blocks.
-    assert 2 * num_mappings <= num_blocks
-    src_blocks = random.sample(range(num_blocks), num_mappings)
-    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
-    block_mapping: list[tuple[int, int]] = []
-    for i in range(num_mappings):
-        src = src_blocks[i]
-        dst1 = dst_blocks[2 * i]
-        dst2 = dst_blocks[2 * i + 1]
-        block_mapping.append((src, dst1))
-        block_mapping.append((src, dst2))
-
-    # Create the KV caches.
-    key_caches, value_caches = kv_cache_factory(
-        num_blocks,
-        block_size,
-        num_layers,
-        num_heads,
-        head_size,
-        kv_cache_dtype,
-        dtype,
-        seed,
-        device,
-    )
-
-    # Clone the KV caches.
-    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
-    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
-
-    # Call the copy blocks kernel.
-    block_mapping_tensor = torch.tensor(
-        block_mapping, dtype=torch.int64, device=device
-    ).view(-1, 2)
-
-    opcheck(
-        torch.ops._C_cache_ops.copy_blocks,
-        (key_caches, value_caches, block_mapping_tensor),
-        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
-        cond=(head_size == HEAD_SIZES[0]),
-    )
-    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
-
-    # Run the reference implementation.
-    for src, dst in block_mapping:
-        for cloned_key_cache in cloned_key_caches:
-            cloned_key_cache[dst].copy_(cloned_key_cache[src])
-        for cloned_value_cache in cloned_value_caches:
-            cloned_value_cache[dst].copy_(cloned_value_cache[src])
-
-    # Compare the results.
-    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
-        torch.testing.assert_close(key_cache, cloned_key_cache)
-    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
-        torch.testing.assert_close(value_cache, cloned_value_cache)
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -763,73 +676,6 @@ def test_concat_and_cache_ds_mla(
     torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
 
 
-@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
-@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
-@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
-@pytest.mark.parametrize("num_layers", NUM_LAYERS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@torch.inference_mode()
-def test_copy_blocks_mla(
-    kv_lora_rank: int,
-    qk_rope_head_dim: int,
-    block_size: int,
-    num_blocks: int,
-    num_layers: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    kv_cache_dtype: str,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    torch.cuda.set_device(device)
-
-    entry_size = kv_lora_rank + qk_rope_head_dim
-
-    kv_caches = []
-    for _ in range(num_layers):
-        kv_cache = _create_mla_cache(
-            num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
-        )
-        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
-        kv_caches.append(kv_cache)
-
-    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
-
-    num_mappings = min(2, num_blocks // 2)
-    src_blocks = random.sample(range(num_blocks), num_mappings)
-    remaining = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remaining, 2 * num_mappings)
-    block_mapping = []
-    for i in range(num_mappings):
-        src = src_blocks[i]
-        dst1 = dst_blocks[2 * i]
-        dst2 = dst_blocks[2 * i + 1]
-        block_mapping.append((src, dst1))
-        block_mapping.append((src, dst2))
-    block_mapping_tensor = torch.tensor(
-        block_mapping, dtype=torch.int64, device=device
-    ).view(-1, 2)
-
-    for src, dst in block_mapping:
-        for ref_cache in ref_caches:
-            ref_cache[dst].copy_(ref_cache[src])
-
-    opcheck(
-        torch.ops._C_cache_ops.copy_blocks_mla,
-        (kv_caches, block_mapping_tensor),
-        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
-    )
-    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
-
-    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
-        torch.testing.assert_close(kv_cache, ref_cache)
-
-
 @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
 @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 78bd8d4e64115..c1519fc177250 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2328,18 +2328,6 @@ def concat_and_cache_mla(
     )
 
 
-def copy_blocks(
-    key_caches: list[torch.Tensor],
-    value_caches: list[torch.Tensor],
-    block_mapping: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
-
-
-def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None:
-    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
-
-
 def swap_blocks(
     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
 ) -> None:
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 95c17cb331f67..239f5376eb462 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -383,18 +383,6 @@ class ipex_ops:
         )
         return None
 
-    @staticmethod
-    def copy_blocks(
-        key_caches: list[torch.Tensor],
-        value_caches: list[torch.Tensor],
-        block_mapping: torch.Tensor,
-    ) -> None:
-        torch.xpu.copy_blocks(  # type: ignore
-            key_caches,
-            value_caches,
-            block_mapping,
-        )
-
     @staticmethod
     def swap_blocks(
         src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor