diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index e1d41930f6231..a70db89cdb76e 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -146,7 +146,7 @@ def _fwd_kernel(Q,
         start_n = tl.multiple_of(start_n, BLOCK_SIZE)
         # -- compute qk ----
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
-                     (start_n // BLOCK_SIZE) * stride_b_loc_s)
+                     (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
         # [D,BLOCK_SIZE]
         off_k = (
             bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
@@ -367,7 +367,7 @@ def _fwd_kernel_flash_attn_v2(
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                      ((start_n + offs_n) // block_size) * stride_b_loc_s,
                      mask=(start_n + offs_n) < cur_batch_ctx_len,
-                     other=0)
+                     other=0).to(tl.int64)
         off_k = (
             bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
             (offs_d[:, None] // x) * stride_k_cache_d +
@@ -575,7 +575,7 @@ def _fwd_kernel_alibi(
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                      ((start_n + offs_n) // block_size) * stride_b_loc_s,
                      mask=(start_n + offs_n) < cur_batch_ctx_len,
-                     other=0)
+                     other=0).to(tl.int64)
         off_k = (
             bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
             (offs_d[:, None] // x) * stride_k_cache_d +
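
Note on the change (the patch itself does not state the motivation, so this is an inference): `bn` holds physical block numbers loaded from the block table, and each kernel multiplies it by the KV-cache strides to form `off_k`. If `bn` stays 32-bit, that product can wrap once the cache spans more than 2^31 elements; casting to `tl.int64` right at the `tl.load` site keeps the computed offsets in range. A minimal NumPy sketch of the wrap-around, not vLLM or Triton code, with an invented stride value for illustration:

```python
import numpy as np

block_number = np.array([70_000], dtype=np.int32)       # entry loaded from the block table
stride_k_cache_bs = np.array([40_960], dtype=np.int32)  # hypothetical elements per cache block

off_32 = block_number * stride_k_cache_bs                   # int32 * int32: wraps silently
off_64 = block_number.astype(np.int64) * stride_k_cache_bs  # widened first, as in the patch

print(off_32)  # [-1427767296]  -> negative, out-of-range offset
print(off_64)  # [2867200000]   -> the intended offset
```

Casting at the load site means every later use of `bn` inherits the 64-bit type, rather than fixing up each offset expression individually.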