From 68d7231991cc307d6865eac5bfca551c06f67465 Mon Sep 17 00:00:00 2001
From: Ryan Rock
Date: Wed, 19 Nov 2025 15:04:36 -0600
Subject: [PATCH] [CI/Build] Fix test_prefix_prefill for AMD (#28905)

Signed-off-by: Ryan Rock
---
 tests/kernels/attention/test_prefix_prefill.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py
index 78cdbbbf7379d..e041e8c8d2ffa 100644
--- a/tests/kernels/attention/test_prefix_prefill.py
+++ b/tests/kernels/attention/test_prefix_prefill.py
@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
    max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
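
Note (commentary, not part of the patch): the functional change in both hunks is where the int32 cast happens. The old form created the tensor as int32 and ran torch.cumsum on it; the new form runs cumsum in the default int64 dtype and casts the result with .to(torch.int32), which guarantees b_start_loc and b_seq_start_loc end up int32 regardless of how a given backend handles integer promotion inside cumsum. A minimal sketch of the two forms, using stand-in values for query_lens:

    # Illustrative only; query_lens stands in for the test's per-request
    # query lengths.
    import torch

    query_lens = [3, 5, 2]

    # Old form: build the tensor as int32, then cumsum over it. The output
    # dtype depends on the backend's integer-promotion behavior for cumsum.
    old = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)

    # New form: cumsum in the default int64 dtype, then cast the result.
    # The final dtype is guaranteed to be torch.int32.
    new = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)

    print(old.dtype, new.dtype)
    assert torch.equal(old.to(torch.int32), new)  # identical values here

For these small non-negative inputs both forms yield identical values; the difference is only the dtype of the result, which the prefix-prefill kernel under test presumably expects to be int32.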