diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py
index 78cdbbbf7379d..e041e8c8d2ffa 100644
--- a/tests/kernels/attention/test_prefix_prefill.py
+++ b/tests/kernels/attention/test_prefix_prefill.py
@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
     )
     for i in range(BS):
         for j in range(query_lens[i]):
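
For context, a minimal, self-contained sketch (not part of the diff) contrasting the two offset-computation patterns. The rationale stated here is an assumption, not confirmed by the diff itself: deferring the cast lets `torch.cumsum` run in PyTorch's default integer dtype (int64) before narrowing to the int32 offsets the test buffers use. The `query_lens` values below are made up for illustration.

```python
import torch

query_lens = [3, 5, 2]  # hypothetical per-request query lengths

# Old pattern: build an int32 tensor, then cumsum over it directly.
old = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)

# New pattern: cumsum in the default dtype (int64), then cast to int32.
new = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)

# Both yield the same int32 prefix-sum offsets for inputs of this size.
assert old.dtype == new.dtype == torch.int32
assert torch.equal(old, new)
print(new)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```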