[Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (#13305)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: maleksan85 <maleksan@amd.com>
Signed-off-by: <>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
This commit is contained in:
Aleksandr Malyshev 2025-04-22 19:11:56 -07:00 committed by GitHub
parent f67e9e9f22
commit bc7c4d206b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 805 additions and 797 deletions

View File

@ -195,15 +195,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
])
@pytest.mark.parametrize("per_test_common_llm_kwargs",
[{
"block_size": 8,
"block_size": 16,
"max_num_batched_tokens": 2,
"max_num_seqs": 2,
}, {
"block_size": 8,
"block_size": 16,
"max_num_batched_tokens": 3,
"max_num_seqs": 2,
}, {
"block_size": 8,
"block_size": 16,
"max_num_batched_tokens": 256,
"max_num_seqs": 10,
}])

View File

@ -16,11 +16,24 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
if triton.__version__ >= "2.1.0":
# Here's an example autotuner config for this kernel. This config does provide
# a performance improvement, but dramatically increases first call latency in
# triton 3.2. Because of this tradeoff, it's currently commented out.
# @triton.autotune(
# configs=[
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
# "num_unroll_cache": 4, \
# "num_unroll_request": 1 } | \
# ({"kpack": 2, "waves_per_eu": 2} \
# if current_platform.is_rocm() else {}), \
# num_warps=4, \
# num_stages=1)
# ],
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
# )
@triton.jit
def _fwd_kernel(
Q,
def _fwd_kernel(Q,
K,
V,
K_cache,
@ -31,8 +44,7 @@ if triton.__version__ >= "2.1.0":
v_scale,
B_Start_Loc,
B_Seqlen,
block_size,
x,
x: tl.constexpr,
Out,
stride_b_loc_b,
stride_b_loc_s,
@ -51,21 +63,25 @@ if triton.__version__ >= "2.1.0":
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_bl: tl.constexpr,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@ -88,6 +104,8 @@ if triton.__version__ >= "2.1.0":
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [BLOCK_SIZE]; starts at 0
offs_bs_n = tl.arange(0, BLOCK_SIZE)
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
@ -95,8 +113,7 @@ if triton.__version__ >= "2.1.0":
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
@ -109,51 +126,54 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M]
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M]
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
loop_unroll_factor=num_unroll_cache):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # [N]
# [D,N]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(start_n // BLOCK_SIZE) * stride_b_loc_s)
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [N,D]
off_v = (
bn[:, None] * stride_v_cache_bs +
# [BLOCK_SIZE,D]
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
offs_bs_n[:, None] * stride_v_cache_bl)
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
k_load = tl.load(
K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
else:
k_load = tl.load(K_cache + off_k)
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
@ -163,31 +183,27 @@ if triton.__version__ >= "2.1.0":
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1) # [M]
p = tl.exp(qk - m_ij[:, None]) # [M,N]
l_ij = tl.sum(p, 1) # [M]
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij) # [M]
alpha = tl.exp(m_i - m_i_new) # [M]
beta = tl.exp(m_ij - m_i_new) # [M]
l_i_new = alpha * l_i + beta * l_ij # [M]
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
@ -196,8 +212,8 @@ if triton.__version__ >= "2.1.0":
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
l_i = l_i * alpha + l_ij
m_i = m_ij
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
@ -210,7 +226,9 @@ if triton.__version__ >= "2.1.0":
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
for start_n in tl.range(0, \
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
loop_unroll_factor=num_unroll_request):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
@ -227,25 +245,16 @@ if triton.__version__ >= "2.1.0":
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :])
< SLIDING_WINDOW, qk, -10000)
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk, -10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
@ -256,19 +265,21 @@ if triton.__version__ >= "2.1.0":
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len))
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
@ -328,13 +339,11 @@ if triton.__version__ >= "2.1.0":
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q,
mask=offs_m[:, None]
< cur_batch_seq_len - cur_batch_ctx_len,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
@ -349,14 +358,12 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
@ -450,8 +457,7 @@ if triton.__version__ >= "2.1.0":
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
@ -459,6 +465,7 @@ if triton.__version__ >= "2.1.0":
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
@ -533,8 +540,7 @@ if triton.__version__ >= "2.1.0":
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
@ -551,8 +557,7 @@ if triton.__version__ >= "2.1.0":
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
@ -561,14 +566,12 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
@ -592,8 +595,8 @@ if triton.__version__ >= "2.1.0":
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
@ -640,8 +643,7 @@ if triton.__version__ >= "2.1.0":
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
@ -650,10 +652,9 @@ if triton.__version__ >= "2.1.0":
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :])
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
@ -667,8 +668,8 @@ if triton.__version__ >= "2.1.0":
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
@ -688,10 +689,9 @@ if triton.__version__ >= "2.1.0":
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None])
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
@ -704,8 +704,7 @@ if triton.__version__ >= "2.1.0":
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
@ -714,6 +713,7 @@ if triton.__version__ >= "2.1.0":
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
@ -735,10 +735,6 @@ if triton.__version__ >= "2.1.0":
skip_decode=False):
q_dtype_is_f32 = q.dtype is torch.float32
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# Turing does have tensor core for float32 multiplication
# use ieee as fallback for triton kernels work. There is also
@ -779,13 +775,18 @@ if triton.__version__ >= "2.1.0":
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
_fwd_kernel_alibi[grid](
q,
k,
@ -821,8 +822,7 @@ if triton.__version__ >= "2.1.0":
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
@ -840,6 +840,13 @@ if triton.__version__ >= "2.1.0":
)
return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
@ -852,7 +859,6 @@ if triton.__version__ >= "2.1.0":
v_scale,
b_start_loc,
b_seq_len,
v_cache.shape[3],
k_cache.shape[4],
o,
b_loc.stride(0),
@ -878,17 +884,19 @@ if triton.__version__ >= "2.1.0":
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
num_stages=1,
)
**extra_kargs)
return