Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 01:05:01 +08:00)
Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
parent 7d5a155e4a
commit d7a1c6d614
@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    num_kv_heads: int = None,
 ) -> None:
     qkv = torch.empty(num_tokens,
                       3,
@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention(
     head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
 
     scale = float(1.0 / (head_size**0.5))
+
+    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+
     output = torch.empty(num_tokens,
                          num_heads,
                          head_size,
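For context beyond the mirrored diff: the new num_kv_heads argument lets the test exercise multi-query / grouped-query attention, where several query heads share one key/value head, and head_mapping records which KV head each query head reads from. A minimal runnable sketch of the mapping the test builds, with illustrative values only (the sketch stays on CPU, while the test allocates on CUDA):

import torch

# Illustrative values, not taken from the test's parameter sweep.
num_heads = 8      # query heads
num_kv_heads = 2   # key/value heads shared among the query heads

assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

# Same construction as the added test code: repeat each KV-head index
# once per query head that shares it.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_queries_per_kv)

print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

With num_kv_heads equal to num_heads, repeat_interleave repeats each index once, so the mapping degenerates to the old torch.arange(num_heads) identity mapping; this is why the default of None preserves the previous behavior of the test.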