From c91b64f749d3deeabf1b035db50553088e1e5c2c Mon Sep 17 00:00:00 2001
From: Liangfu Chen
Date: Mon, 10 Mar 2025 18:37:29 -0700
Subject: [PATCH] [neuron] add reshape_and_cache (#14391)

---
 tests/neuron/test_cache.py           | 83 ++++++++++++++++++++++++++++
 vllm/attention/ops/nki_flash_attn.py | 43 ++++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 tests/neuron/test_cache.py

diff --git a/tests/neuron/test_cache.py b/tests/neuron/test_cache.py
new file mode 100644
index 000000000000..ea33727b7cfa
--- /dev/null
+++ b/tests/neuron/test_cache.py
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from vllm.attention.ops.nki_flash_attn import reshape_and_cache
+
+
+@pytest.mark.parametrize(
+    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
+    [
+        # Small model configuration (e.g., GPT-2 small)
+        (32, 12, 64, 4, 128),  # Typical sequence processing
+        (1, 12, 64, 4, 128),  # Single token update
+        (128, 12, 64, 4, 128),  # Longer sequence
+
+        # Medium model configuration (e.g., GPT-2 medium)
+        (64, 16, 96, 8, 256),  # Standard batch
+        (256, 16, 96, 8, 256),  # Large batch
+
+        # Large model configuration (e.g., GPT-3 style)
+        (48, 32, 128, 16, 512),  # Typical processing window
+        (512, 32, 128, 16, 512),  # Full context window
+
+        # Edge cases and stress tests
+        (1024, 8, 32, 32, 32),  # Many tokens, small heads
+        (16, 64, 256, 4, 64),  # Few tokens, many heads
+        (2048, 24, 128, 64, 128),  # Large scale test
+
+        # Minimal configurations for debugging
+        (4, 2, 16, 2, 16),  # Tiny test case
+        (1, 1, 8, 1, 8),  # Minimal possible
+    ])
+def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
+                           block_size):
+    # Set random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Create CPU tensors for reference implementation
+    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
+        torch.tensor(d_head))
+    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
+        torch.tensor(d_head))
+    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
+    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
+    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]
+
+    # Run reference implementation on CPU
+    block_indices = torch.div(slot_mapping_cpu,
+                              block_size,
+                              rounding_mode="floor")
+    block_offsets = slot_mapping_cpu % block_size
+
+    for i in range(num_tokens):
+        block_idx = block_indices[i]
+        block_offset = block_offsets[i]
+        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
+        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]
+
+    # Create XLA device tensors
+    device = torch.device('xla')
+    key = key_cpu.to(device)
+    value = value_cpu.to(device)
+    key_cache = torch.zeros_like(key_cache_cpu, device=device)
+    value_cache = torch.zeros_like(value_cache_cpu, device=device)
+    slot_mapping = slot_mapping_cpu.to(device)
+
+    # Run vectorized implementation on XLA device
+    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+
+    # Move results back to CPU for comparison
+    key_cache_result = key_cache.cpu()
+    value_cache_result = value_cache.cpu()
+
+    # Assert results match
+    torch.testing.assert_close(key_cache_result,
+                               key_cache_cpu,
+                               rtol=1e-5,
+                               atol=1e-5)
+    torch.testing.assert_close(value_cache_result,
+                               value_cache_cpu,
+                               rtol=1e-5,
+                               atol=1e-5)
diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py
index 20f9dcd163fe..dcf9b0ef1f2a 100644
--- a/vllm/attention/ops/nki_flash_attn.py
+++ b/vllm/attention/ops/nki_flash_attn.py
@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(
 
     o = flash_paged_attention[1, n_kv_head](**kwargs)
     return o
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+) -> None:
+    """
+    Writes key-value pairs to the KV cache at specified positions.
+
+    Args:
+        key (torch.Tensor): Key tensor with shape
+            (num_tokens, n_kv_head, d_head)
+        value (torch.Tensor): Value tensor with shape
+            (num_tokens, n_kv_head, d_head)
+        key_cache (torch.Tensor): Key cache tensor with shape
+            (num_blocks, n_kv_head, block_size, d_head)
+        value_cache (torch.Tensor): Value cache tensor with shape
+            (num_blocks, n_kv_head, block_size, d_head)
+        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
+            with shape (num_tokens)
+
+    Returns:
+        None: Updates the key_cache and value_cache tensors in-place
+    """
+    block_size = key_cache.size(2)
+
+    # Calculate indices with explicit floor division
+    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    block_offsets = slot_mapping % block_size
+
+    # Update caches using index_put_
+    key_cache.index_put_(
+        (block_indices.unsqueeze(1),
+         torch.arange(key_cache.size(1),
+                      device=key.device), block_offsets.unsqueeze(1)), key)
+
+    value_cache.index_put_(
+        (block_indices.unsqueeze(1),
+         torch.arange(value_cache.size(1),
+                      device=value.device), block_offsets.unsqueeze(1)), value)
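
Note on the kernel-side change above: reshape_and_cache writes every token's key/value
into the paged cache with a single broadcast index_put_. The three index tensors of
shapes (num_tokens, 1), (n_kv_head,) and (num_tokens, 1) broadcast to one
(block, head, offset) address per token and head, which is exactly what the new test
compares against a per-token Python loop. The snippet below is a minimal CPU-only
sketch of that same indexing trick, checked against the loop; the small shapes are
arbitrary illustrations and no Neuron/XLA device is assumed.

    import torch

    # Arbitrary small shapes for illustration; the trick is shape-independent.
    num_tokens, n_kv_head, d_head = 4, 2, 8
    num_blocks, block_size = 2, 4

    key = torch.randn(num_tokens, n_kv_head, d_head)
    key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Split each flat slot id into (block index, offset within the block).
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Vectorized write: the index tensors broadcast to (num_tokens, n_kv_head),
    # addressing dims (block, head, offset) of the cache in one operation.
    key_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(n_kv_head),
         block_offsets.unsqueeze(1)), key)

    # Per-token reference loop, mirroring the CPU reference in the new test.
    expected = torch.zeros_like(key_cache)
    for i in range(num_tokens):
        expected[block_indices[i], :, block_offsets[i], :] = key[i]

    torch.testing.assert_close(key_cache, expected)

The patch applies the same pattern to value_cache; keeping the write as one indexed
assignment avoids a Python-level loop over tokens, which is what the test's
"vectorized implementation" comment refers to.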