[Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill (#4128)

commit e8cc7967ff
parent 53b018edcb
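
In brief, the patch drops the hard requirement that the head size be one of {16, 32, 64, 128}: the launcher rounds the head size up to the next power of two, the Triton kernel sizes its blocks by the padded value (BLOCK_DMODEL_PADDED), and a per-lane dim_mask zeroes the padding on every load and on the final store.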
tests/kernels/test_prefix_prefill.py
@@ -10,7 +10,7 @@ from vllm.attention.ops.prefix_prefill import context_attention_fwd
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128]
+HEAD_SIZES = [128, 96]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
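The test now sweeps a non-power-of-two head size (96) alongside 128, which is what exercises the new padding path. As a hedged sketch (the test name here is hypothetical; the real suite parametrizes over these module-level lists), the invariant the padding relies on is:

    import pytest

    HEAD_SIZES = [128, 96]

    @pytest.mark.parametrize("head_size", HEAD_SIZES)
    def test_padded_head_size_is_power_of_two(head_size):
        # pad up to the next power of two, as the kernel launcher does below
        padded = 2 ** ((head_size - 1).bit_length())
        assert padded >= head_size
        assert padded & (padded - 1) == 0  # power-of-two check
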
vllm/attention/ops/prefix_prefill.py
@@ -47,7 +47,8 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
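Passing the head size twice is deliberate: tl.arange in Triton only accepts power-of-two extents, so BLOCK_DMODEL_PADDED is what sizes the blocks and ranges, while BLOCK_DMODEL carries the true head size so the padding lanes can be masked off.
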
@@ -59,26 +60,30 @@ if triton.__version__ >= "2.1.0":
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

         block_start_loc = BLOCK_M * start_m

         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_query_len),
+                    other=0.0)

         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
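The dim_mask idiom is the heart of the change: lanes past the true head size are masked out of every load so they read as 0.0 and drop out of all dot products. A minimal PyTorch sketch (the variable names mirror the kernel, the data is a stand-in) of loading a 96-wide head into a 128-wide block:

    import torch

    BLOCK_DMODEL = 96          # true head size
    BLOCK_DMODEL_PADDED = 128  # rounded up to a power of two

    # True on real head dimensions, False on the padding lanes
    dim_mask = torch.arange(BLOCK_DMODEL_PADDED) < BLOCK_DMODEL

    q_real = torch.randn(BLOCK_DMODEL)
    q_padded = torch.zeros(BLOCK_DMODEL_PADDED)
    q_padded[dim_mask] = q_real  # masked load; padding stays 0.0 (other=0.0)

    # zero padding lanes contribute nothing to q-k dot products
    assert torch.equal(q_padded[:BLOCK_DMODEL], q_real)
    assert q_padded[BLOCK_DMODEL:].abs().sum() == 0
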
@@ -99,7 +104,8 @@ if triton.__version__ >= "2.1.0":
                      offs_d[None, :] * stride_v_cache_d +
                      (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -126,7 +132,8 @@ if triton.__version__ >= "2.1.0":
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)

             p = p.to(v.dtype)
@@ -142,16 +149,15 @@ if triton.__version__ >= "2.1.0":
         k_ptrs = K + off_k
         v_ptrs = V + off_v

-        block_mask = tl.where(
-            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                         other=0.0)

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -179,8 +185,8 @@ if triton.__version__ >= "2.1.0":
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                         other=0.0)

             p = p.to(v.dtype)
@@ -195,7 +201,8 @@ if triton.__version__ >= "2.1.0":
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_query_len))
         return

     @triton.jit
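Note the store is masked on both axes: dim_mask keeps the padded tail of acc from being written back, so Out keeps its unpadded head-size layout, while the row condition (now phrased through cur_batch_query_len) still clips rows beyond the query length.
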
@@ -636,7 +643,8 @@ if triton.__version__ >= "2.1.0":
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
-        assert Lk in {16, 32, 64, 128}
+        # round up Lk to a power of 2 - this is required for Triton block size
+        Lk_padded = 2**((Lk - 1).bit_length())

         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
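The rounding works because (Lk - 1).bit_length() is exactly the exponent of the smallest power of two greater than or equal to Lk. A quick sanity check:

    for lk, expected in [(16, 16), (96, 128), (128, 128), (129, 256)]:
        assert 2**((lk - 1).bit_length()) == expected
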
@@ -646,6 +654,7 @@ if triton.__version__ >= "2.1.0":

         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
+            assert Lk == Lk_padded
             _fwd_kernel_alibi[grid](
                 q,
                 k,
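The new assert records that the ALiBi variant (_fwd_kernel_alibi) did not receive the padding treatment in this change, so that path still requires an already-power-of-two head size. Illustratively:

    Lk = 96
    Lk_padded = 2**((Lk - 1).bit_length())
    print(Lk == Lk_padded)  # False: the ALiBi path would reject head size 96
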
@@ -738,6 +747,7 @@ if triton.__version__ >= "2.1.0":
             num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
+            BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
             num_stages=1,