mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 00:55:40 +08:00
[AMD][Kernel][Bugfix] Cast offsets tensor bn to tl.int64 to avoid GPU segfault (#23692)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
d328f7894f
commit
457e471971
@ -146,7 +146,7 @@ def _fwd_kernel(Q,
|
|||||||
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
|
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
|
||||||
# -- compute qk ----
|
# -- compute qk ----
|
||||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||||
(start_n // BLOCK_SIZE) * stride_b_loc_s)
|
(start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
|
||||||
# [D,BLOCK_SIZE]
|
# [D,BLOCK_SIZE]
|
||||||
off_k = (
|
off_k = (
|
||||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||||
@ -367,7 +367,7 @@ def _fwd_kernel_flash_attn_v2(
|
|||||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||||
other=0)
|
other=0).to(tl.int64)
|
||||||
off_k = (
|
off_k = (
|
||||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||||
@ -575,7 +575,7 @@ def _fwd_kernel_alibi(
|
|||||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||||
other=0)
|
other=0).to(tl.int64)
|
||||||
off_k = (
|
off_k = (
|
||||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user