mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-06 23:36:32 +08:00
[CI/Build] Fix test_prefix_prefill for AMD (#28905)
Signed-off-by: Ryan Rock <ryan.rock@amd.com>
This commit is contained in:
parent
2fd893b4ce
commit
68d7231991
@ -174,11 +174,11 @@ def test_contexted_kv_attention(
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user