# vllm/tests/v1/tpu/test_kv_cache_update_kernel.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(
    page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
):
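    # Exercise the Pallas TPU kv_cache_update kernel against a plain CPU
    # slice-copy reference, covering aligned and unaligned page offsets.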
    page_num = 1000
    padded_num_tokens = 128
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu",
    )
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu",
    )
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
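    # Seven update slices mixing full-page (page_size) and sub-page lengths;
    # their total (19 + 2 * page_size) stays within padded_num_tokens.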
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
    num_kv_update_slices = len(slice_lens)
    kv_cache_start_indices = np.array(
        [
            page_size * 2 - 7,
            page_size * 2,
            page_size * 3,
            page_size * 4 + 6,
            page_size * 5 + 7,
            page_size * 6 + 8,
            page_size * 15 + 3,
        ],
        dtype=np.int32,
    )
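    # Source offsets are the exclusive prefix sum of slice_lens, i.e. the
    # slices are packed back-to-back at the front of new_kv.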
    new_kv_cache_indices = np.concatenate(
        [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
    )
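    # slot_mapping has shape [3, num_slices]: row 0 holds destination starts
    # in the cache, row 1 source starts in new_kv, row 2 slice lengths.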
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
    )
    slot_mapping = np.transpose(slot_mapping)
    slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    num_kv_update_slices_xla = torch.tensor(
        [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
    )
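    # Flush pending host-to-device transfers, then mark the cache buffer as
    # donatable so XLA may alias it for the kernel's output.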
    torch_xla.sync()
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla,
        slot_mapping_xla,
        kv_cache_xla,
        num_kv_update_slices_xla,
        page_size,
        num_slices_per_block,
    )
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()
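    # Reference result: apply the same slice copies to the CPU cache.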
    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
        kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]
    assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)
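

# A minimal sketch of how one might invoke this test, assuming a TPU host with
# a vLLM dev install (the command below is an assumption, not from this file):
#   pytest -s -v tests/v1/tpu/test_kv_cache_update_kernel.py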