Fix paged attention testing. (#495)

Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
Tao Peng 2023-07-25 12:01:56 +08:00 committed by GitHub
parent 7d5a155e4a
commit d7a1c6d614


@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    num_kv_heads: int = None,
 ) -> None:
     qkv = torch.empty(num_tokens,
                       3,
@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention(
     head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
     scale = float(1.0 / (head_size**0.5))
+    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
     output = torch.empty(num_tokens,
                          num_heads,
                          head_size,
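
For reference, a minimal sketch of what the new head_mapping construction computes, using made-up sizes (num_heads=8 and num_kv_heads=2 are illustrative assumptions, not values from the test). Each query head is mapped to the index of the key/value head it shares, so consecutive groups of num_queries_per_kv query heads point at the same KV head:

import torch

# Illustrative sizes (assumptions for this sketch, not taken from the test).
num_heads = 8       # query heads
num_kv_heads = 2    # key/value heads (multi-query / grouped-query attention)

assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

# Query head i attends with KV head head_mapping[i].
# (device="cuda" is dropped here so the sketch also runs on CPU.)
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_queries_per_kv)

print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

When num_kv_heads is left as None, the test falls back to num_kv_heads = num_heads, which reproduces the earlier behaviour of one KV head per query head.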