[Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (#13305)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: maleksan85 <maleksan@amd.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>

parent f67e9e9f22
commit bc7c4d206b
@@ -195,15 +195,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
 ])
 @pytest.mark.parametrize("per_test_common_llm_kwargs",
                          [{
-                             "block_size": 8,
+                             "block_size": 16,
                              "max_num_batched_tokens": 2,
                              "max_num_seqs": 2,
                          }, {
-                             "block_size": 8,
+                             "block_size": 16,
                              "max_num_batched_tokens": 3,
                              "max_num_seqs": 2,
                          }, {
-                             "block_size": 8,
+                             "block_size": 16,
                              "max_num_batched_tokens": 256,
                              "max_num_seqs": 10,
                          }])
@@ -16,11 +16,24 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
 
 # To check compatibility
 IS_TURING = current_platform.get_device_capability() == (7, 5)
 
 if triton.__version__ >= "2.1.0":
 
-    @triton.jit
-    def _fwd_kernel(
-        Q,
+    # Here's an example autotuner config for this kernel. This config does provide
+    # a performance improvement, but dramatically increases first call latency in
+    # triton 3.2. Because of this tradeoff, it's currently commented out.
+    # @triton.autotune(
+    #     configs=[
+    #         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
+    #             "num_unroll_cache": 4, \
+    #             "num_unroll_request": 1 } | \
+    #             ({"kpack": 2, "waves_per_eu": 2} \
+    #             if current_platform.is_rocm() else {}), \
+    #             num_warps=4, \
+    #             num_stages=1)
+    #     ],
+    #     key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
+    # )
+    @triton.jit
+    def _fwd_kernel(Q,
                     K,
                     V,
                     K_cache,
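Note: for anyone who wants to try the commented-out autotuner above, the sketch below shows how a `triton.autotune` decorator of that shape is wired up. It is a minimal illustration against a hypothetical `copy_kernel`, not the vLLM configuration itself:

    import triton
    import triton.language as tl

    # Hypothetical kernel, used only to illustrate the decorator shape.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK": 64}, num_warps=2),
            triton.Config({"BLOCK": 128}, num_warps=4),
        ],
        key=["n_elements"],  # re-tune whenever this argument changes
    )
    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

    # Launch with a callable grid so the geometry tracks the tuned BLOCK:
    # copy_kernel[lambda META: (triton.cdiv(n, META["BLOCK"]),)](src, dst, n)

The first-call-latency tradeoff mentioned in the comment comes from autotuning compiling and timing every config on the first launch for each new key.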
@@ -31,8 +44,7 @@ if triton.__version__ >= "2.1.0":
                     v_scale,
                     B_Start_Loc,
                     B_Seqlen,
-                    block_size,
-                    x,
+                    x: tl.constexpr,
                     Out,
                     stride_b_loc_b,
                     stride_b_loc_s,
@@ -51,21 +63,25 @@ if triton.__version__ >= "2.1.0":
                     stride_k_cache_bs,
                     stride_k_cache_h,
                     stride_k_cache_d,
-                    stride_k_cache_bl,
+                    stride_k_cache_bl: tl.constexpr,
                     stride_k_cache_x,
                     stride_v_cache_bs,
                     stride_v_cache_h,
                     stride_v_cache_d,
                     stride_v_cache_bl,
-                    num_queries_per_kv: int,
+                    num_queries_per_kv: tl.constexpr,
                     IN_PRECISION: tl.constexpr,
                     BLOCK_M: tl.constexpr,
-                    BLOCK_DMODEL: tl.constexpr,  # head size
-                    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
+                    BLOCK_DMODEL: tl.constexpr,
+                    BLOCK_DMODEL_PADDED: tl.constexpr,
+                    BLOCK_SIZE: tl.constexpr,
                     BLOCK_N: tl.constexpr,
                     SLIDING_WINDOW: tl.constexpr,
+                    num_unroll_cache: tl.constexpr,
+                    num_unroll_request: tl.constexpr,
                     SKIP_DECODE: tl.constexpr,
-                    ):
+                    MAX_Q_LEN: tl.constexpr = 0,
+                    MAX_CTX_LEN: tl.constexpr = 0):
 
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
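Note on the signature changes in this hunk: `x`, `stride_k_cache_bl`, and `num_queries_per_kv` become `tl.constexpr`, the runtime `block_size` argument is dropped in favor of a compile-time `BLOCK_SIZE`, and `num_unroll_cache`/`num_unroll_request` arrive as compile-time unroll factors. Marking a parameter `tl.constexpr` bakes its value into the compiled kernel, so it can be used for tensor shapes and the compiler can fold arithmetic such as `% BLOCK_SIZE` and `// x`. A minimal sketch of the difference (the `scale_rows` kernel is hypothetical, not part of this diff):

    import triton
    import triton.language as tl

    @triton.jit
    def scale_rows(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr,
                   REPEAT: tl.constexpr):
        row = tl.program_id(0)
        # BLOCK must be constexpr: it defines a tensor shape.
        offs = tl.arange(0, BLOCK)
        mask = offs < n_cols  # n_cols stays a runtime value; fine in arithmetic
        val = tl.load(x_ptr + row * n_cols + offs, mask=mask, other=0.0)
        # REPEAT is known at compile time, so static_range fully unrolls this.
        for _ in tl.static_range(REPEAT):
            val = val * 2.0
        tl.store(out_ptr + row * n_cols + offs, val, mask=mask)

One kernel binary is generated per distinct (BLOCK, REPEAT) pair, which is also why the new `BLOCK_SIZE` argument specializes `_fwd_kernel` per KV-cache block size.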
@@ -88,6 +104,8 @@ if triton.__version__ >= "2.1.0":
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
+        # [BLOCK_SIZE]; starts at 0
+        offs_bs_n = tl.arange(0, BLOCK_SIZE)
         # [N]; starts at 0
         offs_n = tl.arange(0, BLOCK_N)
         # [D]; starts at 0
@@ -95,8 +113,7 @@ if triton.__version__ >= "2.1.0":
         # [M]; starts at current position in query
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         # [M,D]
-        off_q = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+        off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
                  cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         dim_mask = tl.where(
@@ -109,51 +126,54 @@ if triton.__version__ >= "2.1.0":
                     other=0.0)  # [M,D]
 
         # initialize pointer to m and l
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
-        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
-                       dtype=tl.float32)  # [M,D]
+        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+        l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]
 
         # compute query against context (no causal mask here)
-        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
-            start_n = tl.multiple_of(start_n, BLOCK_N)
+        for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
+                loop_unroll_factor=num_unroll_cache):
+            start_n = tl.multiple_of(start_n, BLOCK_SIZE)
             # -- compute qk ----
             bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
-                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
-                         mask=(start_n + offs_n) < cur_batch_ctx_len,
-                         other=0)  # [N]
-            # [D,N]
-            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_kv_head * stride_k_cache_h +
+                         (start_n // BLOCK_SIZE) * stride_b_loc_s)
+            # [D,BLOCK_SIZE]
+            off_k = (
+                bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
-                ((start_n + offs_n[None, :]) % block_size) *
-                stride_k_cache_bl +
+                ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)
-            # [N,D]
-            off_v = (
-                bn[:, None] * stride_v_cache_bs +
+
+            # [BLOCK_SIZE,D]
+            off_v = (bn[:, None] * stride_v_cache_bs +
                      cur_kv_head * stride_v_cache_h +
                      offs_d[None, :] * stride_v_cache_d +
-                     (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
-            k_load = tl.load(K_cache + off_k,
+                     offs_bs_n[:, None] * stride_v_cache_bl)
+
+            if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
+                    BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
+                k_load = tl.load(
+                    K_cache + off_k,
                     mask=dim_mask[:, None] &
-                    ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+                    ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
                     other=0.0)  # [D,N]
+            else:
+                k_load = tl.load(K_cache + off_k)
+
             if k_load.dtype.is_fp8():
                 k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
             else:
                 k = k_load
 
-            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
+            qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
             qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
-            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
+            qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
             if SLIDING_WINDOW > 0:
                 # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
                 # Q entries in sequence
-                # (start_n + offs_n[None, :]) are the positions of
+                # (start_n + offs_bs_n[None, :]) are the positions of
                 # KV entries in sequence
                 # So the condition makes sure each entry in Q only attends
                 # to KV entries not more than SLIDING_WINDOW away.
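Note on the loop change above: the context loop now walks the KV cache one physical block (`BLOCK_SIZE`) at a time instead of one `BLOCK_N` tile, so the block-table entry `bn` is fetched once per block, and the masked load is only needed when the `start_n + BLOCK_SIZE > cur_batch_ctx_len` test flags a ragged final block; full blocks take the cheap unmasked path. `tl.range(..., loop_unroll_factor=...)` additionally asks the compiler to replicate the loop body. A minimal sketch of the unrolling knob on a standalone kernel (assumes a recent Triton that supports `loop_unroll_factor`; `block_sum` is illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def block_sum(x_ptr, out_ptr, n, BLOCK: tl.constexpr, UNROLL: tl.constexpr):
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        # The body is unrolled UNROLL times, exposing more independent loads
        # per iteration (useful on ROCm, per this PR's tuning).
        for start in tl.range(0, n, BLOCK, loop_unroll_factor=UNROLL):
            offs = start + tl.arange(0, BLOCK)
            acc += tl.load(x_ptr + offs, mask=offs < n, other=0.0)
        tl.store(out_ptr + tl.arange(0, BLOCK), acc)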
@@ -163,31 +183,27 @@ if triton.__version__ >= "2.1.0":
                 # This then makes m_ij contain -inf, which causes NaNs in
                 # exp().
                 qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
-                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
+                              (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
                               -10000)
 
-            # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)  # [M]
-            p = tl.exp(qk - m_ij[:, None])  # [M,N]
-            l_ij = tl.sum(p, 1)  # [M]
-            # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)  # [M]
-            alpha = tl.exp(m_i - m_i_new)  # [M]
-            beta = tl.exp(m_ij - m_i_new)  # [M]
-            l_i_new = alpha * l_i + beta * l_ij  # [M]
+            # compute running maximum
+            m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, axis=1)
+            alpha = tl.exp(m_i - m_ij)
+            acc = acc * alpha[:, None]
 
             # -- update output accumulator --
-            # scale p
-            p_scale = beta / l_i_new
-            p = p * p_scale[:, None]
-            # scale acc
-            acc_scale = l_i / l_i_new * alpha
-            acc = acc * acc_scale[:, None]
             # update acc
-            v_load = tl.load(V_cache + off_v,
+            if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
+                    BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
+                v_load = tl.load(
+                    V_cache + off_v,
                     mask=dim_mask[None, :] &
-                    ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
+                    ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
                     other=0.0)  # [N,D]
+            else:
+                v_load = tl.load(V_cache + off_v)
+
             if v_load.dtype.is_fp8():
                 v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
             else:
@@ -196,8 +212,8 @@ if triton.__version__ >= "2.1.0":
 
             acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # # update m_i and l_i
-            l_i = l_i_new
-            m_i = m_i_new
+            l_i = l_i * alpha + l_ij
+            m_i = m_ij
 
         off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
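Note on the softmax rewrite in the last two hunks: the old code kept `p` and `acc` normalized by the running denominator on every tile (`p_scale = beta / l_i_new`, `acc_scale = l_i / l_i_new * alpha`); the new code uses the FlashAttention-2 style update, rescaling `acc` by `alpha = exp(m_i - m_ij)` only, accumulating unnormalized `p @ v`, and dividing by `l_i` once after the loop. Initializing `l_i` to 1.0 is harmless because `alpha` is 0 on the first non-masked tile (`m_i` starts at -inf), which wipes the placeholder. A small NumPy sanity check of the equivalence (tile size and shapes are arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    scores = rng.standard_normal((4, 32))  # one block of query rows x 32 keys
    v = rng.standard_normal((32, 8))

    # Reference: full softmax over all keys, then weighted sum of V.
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    ref = (e / e.sum(axis=1, keepdims=True)) @ v

    # Streaming over key tiles, as in the new kernel.
    m_i = np.full(4, -np.inf)
    l_i = np.ones(4)                 # matches tl.full([BLOCK_M], 1.0, ...)
    acc = np.zeros((4, 8))
    for s in range(0, 32, 8):
        qk = scores[:, s:s + 8]
        m_ij = np.maximum(m_i, qk.max(axis=1))
        p = np.exp(qk - m_ij[:, None])
        alpha = np.exp(m_i - m_ij)
        acc = acc * alpha[:, None] + p @ v[s:s + 8]
        l_i = l_i * alpha + p.sum(axis=1)
        m_i = m_ij
    acc /= l_i[:, None]

    assert np.allclose(acc, ref)

One rescale per tile plus a single final normalization is the main arithmetic saving over the old per-tile division scheme.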
@@ -210,7 +226,9 @@ if triton.__version__ >= "2.1.0":
         block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
         # compute query against itself (with causal mask)
-        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+        for start_n in tl.range(0, \
+                block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
+                loop_unroll_factor=num_unroll_request):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             k = tl.load(k_ptrs +
@@ -227,25 +245,16 @@ if triton.__version__ >= "2.1.0":
                           float("-inf"))
             if SLIDING_WINDOW > 0:
                 qk = tl.where(
-                    offs_m[:, None] - (start_n + offs_n[None, :])
-                    < SLIDING_WINDOW, qk, -10000)
+                    offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
+                    qk, -10000)
 
-            # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)
+            # compute running maximum
+            m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
             p = tl.exp(qk - m_ij[:, None])
-            l_ij = tl.sum(p, 1)
-            # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)
-            alpha = tl.exp(m_i - m_i_new)
-            beta = tl.exp(m_ij - m_i_new)
-            l_i_new = alpha * l_i + beta * l_ij
-            # -- update output accumulator --
-            # scale p
-            p_scale = beta / l_i_new
-            p = p * p_scale[:, None]
-            # scale acc
-            acc_scale = l_i / l_i_new * alpha
-            acc = acc * acc_scale[:, None]
+            l_ij = tl.sum(p, axis=1)
+            alpha = tl.exp(m_i - m_ij)
+            acc = acc * alpha[:, None]
 
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
@@ -256,21 +265,23 @@ if triton.__version__ >= "2.1.0":
 
             acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # update m_i and l_i
-            l_i = l_i_new
-            m_i = m_i_new
+            l_i = l_i * alpha + l_ij
+            m_i = m_ij
+
+        acc = acc / l_i[:, None]
 
         # initialize pointers to output
-        off_o = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+        off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
                  cur_head * stride_oh + offs_d[None, :] * stride_od)
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=dim_mask[None, :] &
-                 (offs_m[:, None] < cur_batch_query_len))
+                 mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
         return
 
-    @triton.jit
-    def _fwd_kernel_flash_attn_v2(
+
+    @triton.jit
+    def _fwd_kernel_flash_attn_v2(
         Q,
         K,
         V,
@@ -311,7 +322,7 @@ if triton.__version__ >= "2.1.0":
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
-    ):
+            ):
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
@@ -328,13 +339,11 @@ if triton.__version__ >= "2.1.0":
         offs_n = tl.arange(0, BLOCK_N)
         offs_d = tl.arange(0, BLOCK_DMODEL)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        off_q = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+        off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
                  cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         q = tl.load(Q + off_q,
-                    mask=offs_m[:, None]
-                    < cur_batch_seq_len - cur_batch_ctx_len,
+                    mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                     other=0.0)
 
         # # initialize pointer to m and l
@@ -349,14 +358,12 @@ if triton.__version__ >= "2.1.0":
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
-            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_kv_head * stride_k_cache_h +
+            off_k = (
+                bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
-                ((start_n + offs_n[None, :]) % block_size) *
-                stride_k_cache_bl +
+                ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)
-            off_v = (
-                bn[:, None] * stride_v_cache_bs +
+            off_v = (bn[:, None] * stride_v_cache_bs +
                      cur_kv_head * stride_v_cache_h +
                      offs_d[None, :] * stride_v_cache_d +
                      (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
@@ -450,8 +457,7 @@ if triton.__version__ >= "2.1.0":
 
         # acc /= l_i[:, None]
         # initialize pointers to output
-        off_o = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+        off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
                  cur_head * stride_oh + offs_d[None, :] * stride_od)
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
@@ -459,8 +465,9 @@ if triton.__version__ >= "2.1.0":
                  mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
         return
 
-    @triton.jit
-    def _fwd_kernel_alibi(
+
+    @triton.jit
+    def _fwd_kernel_alibi(
         Q,
         K,
         V,
@@ -506,7 +513,7 @@ if triton.__version__ >= "2.1.0":
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
         SKIP_DECODE: tl.constexpr,
-    ):
+            ):
         # attn_bias[]
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
@@ -533,8 +540,7 @@ if triton.__version__ >= "2.1.0":
         offs_n = tl.arange(0, BLOCK_N)
         offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        off_q = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+        off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
                  cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         dim_mask = tl.where(
@@ -551,8 +557,7 @@ if triton.__version__ >= "2.1.0":
         acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         alibi_slope = tl.load(Alibi_slopes + cur_head)
-        alibi_start_q = tl.arange(
-            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
+        alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
         alibi_start_k = 0
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -561,14 +566,12 @@ if triton.__version__ >= "2.1.0":
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
-            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_kv_head * stride_k_cache_h +
+            off_k = (
+                bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
-                ((start_n + offs_n[None, :]) % block_size) *
-                stride_k_cache_bl +
+                ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)
-            off_v = (
-                bn[:, None] * stride_v_cache_bs +
+            off_v = (bn[:, None] * stride_v_cache_bs +
                      cur_kv_head * stride_v_cache_h +
                      offs_d[None, :] * stride_v_cache_d +
                      (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
@@ -592,8 +595,8 @@ if triton.__version__ >= "2.1.0":
             alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                      alibi_start_q[:, None]) * alibi_slope
             alibi = tl.where(
-                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
-                alibi, float("-inf"))
+                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
+                float("-inf"))
             qk += alibi
             alibi_start_k += BLOCK_N
 
@@ -640,8 +643,7 @@ if triton.__version__ >= "2.1.0":
 
         # init alibi
         alibi_slope = tl.load(Alibi_slopes + cur_head)
-        alibi_start_q = tl.arange(
-            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
+        alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
         alibi_start_k = cur_batch_ctx_len
         # # init debugger
         # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
@@ -650,10 +652,9 @@ if triton.__version__ >= "2.1.0":
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
-            k = tl.load(k_ptrs +
-                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=dim_mask[:, None] &
-                        ((start_n + offs_n[None, :])
+            k = tl.load(
+                k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
                                           < cur_batch_seq_len - cur_batch_ctx_len),
                 other=0.0)
 
@@ -667,8 +668,8 @@ if triton.__version__ >= "2.1.0":
             alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                      alibi_start_q[:, None]) * alibi_slope
             alibi = tl.where(
-                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
-                alibi, float("-inf"))
+                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
+                float("-inf"))
             qk += alibi
             alibi_start_k += BLOCK_N
 
@@ -688,10 +689,9 @@ if triton.__version__ >= "2.1.0":
             # acc_scale = l_i / l_i_new * alpha
             acc = acc * acc_scale[:, None]
             # update acc
-            v = tl.load(v_ptrs +
-                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=dim_mask[None, :] &
-                        ((start_n + offs_n[:, None])
+            v = tl.load(
+                v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
                                           < cur_batch_seq_len - cur_batch_ctx_len),
                 other=0.0)
             p = p.to(v.dtype)
@@ -704,8 +704,7 @@ if triton.__version__ >= "2.1.0":
         acc = acc / l_i[:, None]
 
         # initialize pointers to output
-        off_o = (
-            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+        off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
                  cur_head * stride_oh + offs_d[None, :] * stride_od)
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
@@ -714,8 +713,9 @@ if triton.__version__ >= "2.1.0":
                  (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
         return
 
-    @torch.inference_mode()
-    def context_attention_fwd(q,
+
+    @torch.inference_mode()
+    def context_attention_fwd(q,
                               k,
                               v,
                               o,
@@ -735,10 +735,6 @@ if triton.__version__ >= "2.1.0":
                               skip_decode=False):
 
         q_dtype_is_f32 = q.dtype is torch.float32
-        # need to reduce num. blocks when using fp32
-        # due to increased use of GPU shared memory
-        # if q.dtype is torch.float32:
-        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
 
         # Turing does not have tensor cores for float32 multiplication;
         # use ieee as a fallback so the triton kernels work. There is also
@@ -779,13 +775,18 @@ if triton.__version__ >= "2.1.0":
         num_queries_per_kv = q.shape[1] // k.shape[1]
 
         assert batch + 1 == len(b_start_loc)
-        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
 
         # 0 means "disable"
         if sliding_window is None or sliding_window <= 0:
             sliding_window = 0
 
         if alibi_slopes is not None:
+            # need to reduce num. blocks when using fp32
+            # due to increased use of GPU shared memory
+            # if q.dtype is torch.float32:
+            BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
+            # batch, head,
+            grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
             _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -821,8 +822,7 @@ if triton.__version__ >= "2.1.0":
                 k_cache.stride(2),
                 k_cache.stride(3),
-                k_cache.stride(
-                    4
-                ),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
+                k_cache.stride(
+                    4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
                 v_cache.stride(0),
                 v_cache.stride(1),
                 v_cache.stride(2),
@@ -840,6 +840,13 @@ if triton.__version__ >= "2.1.0":
             )
             return
 
+        max_seq_len = 0 if max_seq_len is None else max_seq_len
+        extra_kargs = {}
+        if current_platform.is_rocm():
+            extra_kargs = {"kpack": 2, "waves_per_eu": 2}
+
+        grid = lambda META: (batch, head,
+                             triton.cdiv(max_input_len, META["BLOCK_M"]))
         _fwd_kernel[grid](
             q,
             k,
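Note: `_fwd_kernel` is now launched through a callable grid. Triton invokes the callable with the final meta-parameters, so the third grid axis always matches whatever `BLOCK_M` the kernel is compiled with; this matters if the commented-out autotuner is ever re-enabled. A minimal sketch with made-up sizes:

    import triton

    batch, head, max_input_len = 8, 32, 4096
    grid = lambda META: (batch, head,
                         triton.cdiv(max_input_len, META["BLOCK_M"]))
    # _fwd_kernel[grid](..., BLOCK_M=128, ...)  -> grid (8, 32, 32)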
@@ -852,7 +859,6 @@ if triton.__version__ >= "2.1.0":
             v_scale,
             b_start_loc,
             b_seq_len,
-            v_cache.shape[3],
             k_cache.shape[4],
             o,
             b_loc.stride(0),
@@ -878,17 +884,19 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(0),
             v_cache.stride(1),
             v_cache.stride(2),
-            v_cache.stride(
-                3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            BLOCK_SIZE=v_cache.shape[3],
             num_queries_per_kv=num_queries_per_kv,
             IN_PRECISION=IN_PRECISION,
-            BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
-            BLOCK_N=BLOCK,
             SLIDING_WINDOW=sliding_window,
             SKIP_DECODE=skip_decode,
-            num_warps=NUM_WARPS,
-        )
+            BLOCK_M=128,
+            BLOCK_N=64,
+            num_unroll_cache=4,
+            num_unroll_request=1,
+            num_warps=4,
+            num_stages=1,
+            **extra_kargs)
         return
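Note on the ROCm-only launch arguments: `kpack` and `waves_per_eu` are AMD backend compile hints (as I understand them, k-dimension packing for MFMA instructions and a target occupancy in waves per execution unit); other backends would reject unknown keywords, hence the conditional splice. The pattern in isolation:

    from vllm.platforms import current_platform  # as imported in this file

    extra_kargs = ({"kpack": 2, "waves_per_eu": 2}
                   if current_platform.is_rocm() else {})
    # kernel[grid](..., num_warps=4, num_stages=1, **extra_kargs)

Together with the fixed `BLOCK_M=128`, `BLOCK_N=64`, `num_warps=4`, and `num_stages=1`, this replaces the old `BLOCK`/`NUM_WARPS` heuristics on the non-ALiBi path.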