[CI/Build] Fix test_prefix_prefill for AMD (#28905)

Signed-off-by: Ryan Rock <ryan.rock@amd.com>
Author: Ryan Rock, 2025-11-19 15:04:36 -06:00 (committed by GitHub)
parent 2fd893b4ce
commit 68d7231991


@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
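
Both hunks make the same change: the int32 cast moves from the tensor constructor to a .to(torch.int32) applied after torch.cumsum, so the start-offset tensors are guaranteed to reach the attention kernel as int32 regardless of how cumsum handles integer dtypes on the device. The commit does not spell out the motivation; a minimal sketch of the before/after pattern follows, with made-up query_lens values standing in for the test's own:

    import torch

    # Hypothetical per-request query lengths, standing in for the test's values.
    query_lens = [16, 8, 32]

    # Old form: dtype pinned at tensor construction, cumsum applied afterwards.
    old = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)

    # New form: build the tensor in the default int64, cumsum, then down-cast.
    new = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)

    print(old.dtype, new.dtype)  # only the new form is guaranteed to be int32
    print(new.tolist())          # [0, 16, 24, 56]

The values are identical either way; only the new form pins the final dtype explicitly.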