mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:38:38 +08:00
[TPU][Bugfix] fix test_pallas (#20666)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
4ac9c33f78
commit
eb58f5953d
@ -50,6 +50,7 @@ def test_ragged_paged_attention():
|
|||||||
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
|
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
|
||||||
max_num_reqs = 8
|
max_num_reqs = 8
|
||||||
max_num_blocks_per_req = 8
|
max_num_blocks_per_req = 8
|
||||||
|
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
|
||||||
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
|
context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
|
||||||
@ -65,6 +66,7 @@ def test_ragged_paged_attention():
|
|||||||
context_lens=context_lens,
|
context_lens=context_lens,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
num_seqs=num_seqs,
|
num_seqs=num_seqs,
|
||||||
|
num_kv_update_slices=num_kv_update_slices,
|
||||||
num_slices_per_kv_cache_update_block=8,
|
num_slices_per_kv_cache_update_block=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user