mirror of https://git.datalinker.icu/vllm-project/vllm.git
Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
commit d7a1c6d614
parent 7d5a155e4a
@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    num_kv_heads: int = None,
 ) -> None:
     qkv = torch.empty(num_tokens,
                       3,
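The new num_kv_heads parameter defaults to None, which the second hunk below resolves back to num_heads, so existing call sites keep the original multi-head behavior. A hypothetical invocation follows; the real call sites are not part of this diff, and the parameters before block_size (num_tokens, num_heads, head_size) are inferred from the function body, not shown in this hunk:

import torch

# Hypothetical call exercising the grouped-query path: 8 query heads
# grouped onto 2 KV heads. Omitting num_kv_heads tests plain MHA.
run_single_query_cached_kv_attention(
    num_tokens=7,
    num_heads=8,
    head_size=64,
    block_size=16,
    num_blocks=128,
    dtype=torch.half,
    num_kv_heads=2,
)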
@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention(
     head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")

     scale = float(1.0 / (head_size**0.5))
+
+    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+
     output = torch.empty(num_tokens,
                          num_heads,
                          head_size,
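For context, a minimal CPU-only sketch of the head_mapping construction added above: torch.repeat_interleave expands the KV-head indices so that query head i reads from KV head i // num_queries_per_kv (MQA when num_kv_heads == 1, GQA for values in between). The device="cuda" argument is dropped here so the snippet runs anywhere:

import torch

num_heads = 8
num_kv_heads = 2
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)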