diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py
new file mode 100644
index 0000000000000..8c5a63653db9f
--- /dev/null
+++ b/tests/v1/attention/test_chunked_local_attention.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+import torch
+
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
+from vllm.v1.attention.backends.utils import (
+    make_local_attention_virtual_batches)
+
+
+@dataclass
+class LocalAttentionTestData:
+    # Input parameters
+    batch_spec: BatchSpec
+    attn_chunk_size: int
+    block_size: int
+    # Expected return values
+    expected_q_seqlens: list[int]
+    expected_k_seqlens: list[int]
+    expected_local_block_table: list[list[int]]
+
+
+test_data_list = [
+    # Same as example in docstring of make_local_attention_virtual_batches
+    # except the block table has 9 columns instead of 10
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[4, 10, 5],
+            seq_lens=[6, 17, 9],
+        ),
+        attn_chunk_size=4,
+        block_size=2,
+        expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
+        expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
+        # 2 pages per local batch
+        # (chunk size 4 // block size 2)
+        expected_local_block_table=[
+            [0, 1],    # local-batch 0, (batch 0, starting from k[0])
+            [2, 3],    # local-batch 1, (batch 0, starting from k[4])
+            [11, 12],  # local-batch 2, (batch 1, starting from k[4])
+            [13, 14],  # local-batch 3, (batch 1, starting from k[8])
+            [15, 16],  # local-batch 4, (batch 1, starting from k[12])
+            [17, 17],  # local-batch 5, (batch 1, starting from k[16])
+            [20, 21],  # local-batch 6, (batch 2, starting from k[4])
+            [22, 23],  # local-batch 7, (batch 2, starting from k[8])
+        ]),
+    # Case where block indices are not clipped to block table ncols-1
+    # because tokens_in_last_block == attn_chunk_size
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[8],
+            seq_lens=[12],
+        ),
+        attn_chunk_size=4,
+        block_size=2,
+        expected_q_seqlens=[4, 4],
+        expected_k_seqlens=[4, 4],
+        expected_local_block_table=[
+            [2, 3],
+            [4, 5],
+        ]),
+    # Case where all kv_seq positions are involved in attn
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[7],
+            # 10 - 7 = 3 previously computed tokens
+            seq_lens=[10],
+        ),
+        attn_chunk_size=4,
+        block_size=2,
+        expected_q_seqlens=[1, 4, 2],
+        expected_k_seqlens=[4, 4, 2],
+        expected_local_block_table=[
+            [0, 1],
+            [2, 3],
+            [4, 4],
+        ]),
+    # Case where attn_chunk_size > kv_seq_len,
+    # so no extra mini virtual batches are created
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[4],
+            seq_lens=[6],
+        ),
+        # Larger than kv_seq_len
+        attn_chunk_size=10,
+        block_size=2,
+        # No change to q_seqlens and k_seqlens
+        expected_q_seqlens=[4],
+        expected_k_seqlens=[6],
+        # In this case, we only need a block-table like:
+        #   block_table = [ [0, 1, 2] ]  # 1 batch, 3 pages
+        # but we need to pad it to 5 pages per local batch
+        # because currently pages_per_local_batch
+        # is calculated as (attn_chunk_size // block_size)
+        expected_local_block_table=[
+            [0, 1, 2, 2, 2],
+        ]),
+    # Block size equal to chunk size
+    # Expect a single page per batch in the local block table
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[6, 6],
+            seq_lens=[8, 8],
+        ),
+        attn_chunk_size=4,
+        block_size=4,
+        expected_q_seqlens=[2, 4, 2, 4],
+        expected_k_seqlens=[4, 4, 4, 4],
+        # Initial block table = [
+        #     [0, 1],  < batch 0
+        #     [2, 3],  < batch 1
+        # ]
+        expected_local_block_table=[
+            [0],  # local-batch 0, (batch 0, starting from k[0])
+            [1],  # local-batch 1, (batch 0, starting from k[4])
+            [2],  # local-batch 2, (batch 1, starting from k[0])
+            [3],  # local-batch 3, (batch 1, starting from k[4])
+        ]),
+    # Case where the query falls in the second attention chunk
+    # k_toks >   0 1 2 3 4
+    # q_toks v  _____________
+    #        0 | 1
+    #        1 | 1 1
+    #        2 | 1 1 1
+    #        3 | 1 1 1 1
+    #        4 |         1
+    # where tokens 0,1,2,3 have been pre-computed
+    LocalAttentionTestData(
+        batch_spec=BatchSpec(
+            query_lens=[1],
+            seq_lens=[5],
+        ),
+        attn_chunk_size=4,
+        block_size=2,
+        expected_q_seqlens=[1],
+        expected_k_seqlens=[1],
+        expected_local_block_table=[
+            [2, 2],
+        ]),
+]
+
+
+@pytest.mark.parametrize("test_data", test_data_list)
+def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
+    device = torch.device("cuda:0")
+    batch_spec = test_data.batch_spec
+    attn_chunk_size = test_data.attn_chunk_size
+    block_size = test_data.block_size
+    expected_q_seqlens = test_data.expected_q_seqlens
+    expected_k_seqlens = test_data.expected_k_seqlens
+    expected_local_block_table = test_data.expected_local_block_table
+
+    # Create common attention metadata
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size,
+        device,
+        # Use torch.arange instead of torch.randint so we can assert on
+        # block table tensor values. The block table will have shape
+        # (num_batches, cdiv(max_seq_len, block_size)) and its values will be
+        # arange'd from 0 to num_batches * cdiv(max_seq_len, block_size) - 1
+        arange_block_indices=True,
+    )
+
+    # Call the function
+    result = make_local_attention_virtual_batches(attn_chunk_size,
+                                                  common_attn_metadata,
+                                                  block_size)
+
+    # Convert to numpy for easier comparison
+    actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
+    actual_k_seqlens = result.seq_lens_cpu.numpy()
+
+    # Check that all query lengths are less than or equal to attn_chunk_size
+    assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
+    # Check that all key lengths are less than or equal to attn_chunk_size
+    assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
+    # Check that the total number of query tokens is preserved
+    assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)
+
+    # Verify results
+    np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
+    np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)
+
+    expected_block_table_tensor = torch.tensor(expected_local_block_table,
+                                               dtype=torch.int32,
+                                               device=device)
+
+    print(f"Expected block table:\n{expected_block_table_tensor}")
+    print(f"Actual block table:\n{result.block_table_tensor}")
+
+    torch.testing.assert_close(result.block_table_tensor,
+                               expected_block_table_tensor)
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index ae2ab6e6413c0..be6cfce6fba8a 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -40,7 +40,8 @@ def create_common_attn_metadata(
         batch_spec: BatchSpec,
         block_size: int,
         device: torch.device,
-        max_block_idx: int = 1000) -> CommonAttentionMetadata:
+        max_block_idx: int = 1000,
+        arange_block_indices: bool = False) -> CommonAttentionMetadata:
     """Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
     # Create query start locations
     query_start_loc = torch.zeros(batch_spec.batch_size + 1,
@@ -65,19 +66,28 @@ def create_common_attn_metadata(
     ]
     num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
 
-    # Create block table (random for testing)
+    # Create block table and slot mapping
     max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
-    block_table_tensor = torch.randint(0,
-                                       max_block_idx,
-                                       (batch_spec.batch_size, max_blocks),
-                                       dtype=torch.int32,
-                                       device=device)
-
-    # Create slot mapping
-    slot_mapping = torch.randint(0,
-                                 max_block_idx, (num_tokens, ),
-                                 dtype=torch.int64,
-                                 device=device)
+    if arange_block_indices:
+        num_blocks = batch_spec.batch_size * max_blocks
+        block_table_tensor = torch.arange(num_blocks,
+                                          dtype=torch.int32,
+                                          device=device).view(
+                                              batch_spec.batch_size,
+                                              max_blocks)
+        slot_mapping = torch.arange(num_tokens,
+                                    dtype=torch.int64,
+                                    device=device).view(num_tokens)
+    else:
+        block_table_tensor = torch.randint(0,
+                                           max_block_idx,
+                                           (batch_spec.batch_size, max_blocks),
+                                           dtype=torch.int32,
+                                           device=device)
+        slot_mapping = torch.randint(0,
+                                     max_block_idx, (num_tokens, ),
+                                     dtype=torch.int64,
+                                     device=device)
 
     # Calculate max query length
     max_query_len = max(batch_spec.query_lens)
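
Note (not part of the diff): the exact block-table indices asserted in the test, e.g. `[11, 12]` for batch 1 starting at `k[4]`, follow from the deterministic layout produced by the new `arange_block_indices=True` path. A minimal standalone sketch for the first test case (`seq_lens=[6, 17, 9]`, `block_size=2`), using the same arange construction as the utils change above:

```python
import torch

# First test case: 3 requests, block_size=2, longest seq_len=17
batch_size = 3
seq_lens = [6, 17, 9]
block_size = 2

# cdiv(max(seq_lens), block_size) -> widest block-table row needed
max_blocks = (max(seq_lens) + block_size - 1) // block_size  # 9

# Deterministic block table: row i holds global block indices
# [i * max_blocks, ..., (i + 1) * max_blocks - 1]
block_table = torch.arange(batch_size * max_blocks,
                           dtype=torch.int32).view(batch_size, max_blocks)
print(block_table)
# tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
#         [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
#         [18, 19, 20, 21, 22, 23, 24, 25, 26]], dtype=torch.int32)

# Batch 1's local batch starting at k[4] covers blocks 4 // 2 = 2 and 3 of
# its row, i.e. global indices 9 + 2 = 11 and 9 + 3 = 12 -> the expected
# [11, 12] in expected_local_block_table.
```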