From d7a1c6d614756b3072df3e8b52c0998035fb453f Mon Sep 17 00:00:00 2001
From: Tao Peng
Date: Tue, 25 Jul 2023 12:01:56 +0800
Subject: [PATCH] Fix paged attention testing. (#495)

Signed-off-by: Tao Peng
---
 tests/kernels/test_attention.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 4c02f33ca577..d8199c8e6075 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    num_kv_heads: int = None,
 ) -> None:
     qkv = torch.empty(num_tokens,
                       3,
@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention(
     head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
     scale = float(1.0 / (head_size**0.5))
+
+    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+
     output = torch.empty(num_tokens,
                          num_heads,
                          head_size,
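
Illustrative note (a sketch, not part of the commit): the head_mapping built by the added
lines assigns each query head to the KV head it reads from, so the test can exercise
multi-/grouped-query attention as well as plain multi-head attention. The snippet below uses
hypothetical sizes (num_heads=8, num_kv_heads=2) and drops the CUDA device argument so it
runs on CPU; everything else mirrors the logic introduced by the patch.

import torch

# Hypothetical sizes for illustration only; the test derives these from its parameters.
num_heads = 8
num_kv_heads = 2
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

Query heads 0-3 read KV head 0 and query heads 4-7 read KV head 1. When num_kv_heads is left
as None it defaults to num_heads, and the mapping reduces to the identity arange the test
used before this change.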