mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:27:08 +08:00
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (#14431)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
This commit is contained in:
parent
0b1cfa6180
commit
fb4c7f8ef0
@ -1,9 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
# Authors:
|
# Authors:
|
||||||
# - Burkhard Ringlein
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||||
# - Jan van Lunteren
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||||
# - Thomas Parnell
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||||
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
|
|||||||
v_scale, # float32
|
v_scale, # float32
|
||||||
num_query_heads: tl.constexpr, # int
|
num_query_heads: tl.constexpr, # int
|
||||||
num_queries_per_kv: tl.constexpr, # int
|
num_queries_per_kv: tl.constexpr, # int
|
||||||
|
num_queries_per_kv_padded: tl.constexpr, # int
|
||||||
block_table_stride: tl.constexpr, # int
|
block_table_stride: tl.constexpr, # int
|
||||||
query_stride_0: tl.constexpr, # int
|
query_stride_0: tl.constexpr, # int
|
||||||
query_stride_1: tl.constexpr, # int, should be equal to head_size
|
query_stride_1: tl.constexpr, # int, should be equal to head_size
|
||||||
@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
|
|||||||
query_start_len_ptr, # [num_seqs+1]
|
query_start_len_ptr, # [num_seqs+1]
|
||||||
):
|
):
|
||||||
seq_idx = tl.program_id(0)
|
seq_idx = tl.program_id(0)
|
||||||
query_head_idx = tl.program_id(1)
|
kv_head_idx = tl.program_id(1)
|
||||||
kv_head_idx = query_head_idx // num_queries_per_kv
|
|
||||||
|
|
||||||
if filter_by_query_len:
|
if filter_by_query_len:
|
||||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||||
@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
|
|||||||
else:
|
else:
|
||||||
cur_batch_in_all_start_index = seq_idx
|
cur_batch_in_all_start_index = seq_idx
|
||||||
|
|
||||||
|
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
|
||||||
|
0, num_queries_per_kv_padded)
|
||||||
|
|
||||||
query_offset = (cur_batch_in_all_start_index * query_stride_0 +
|
query_offset = (cur_batch_in_all_start_index * query_stride_0 +
|
||||||
query_head_idx * query_stride_1)
|
query_head_idx[:, None] * query_stride_1)
|
||||||
|
|
||||||
|
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
|
||||||
|
head_mask = head_mask & (query_head_idx < num_query_heads)
|
||||||
|
|
||||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||||
0).to(tl.int1)
|
0).to(tl.int1)
|
||||||
|
|
||||||
# Q : (HEAD_SIZE,)
|
# Q : (num_queries_per_kv, HEAD_SIZE,)
|
||||||
Q = tl.load(
|
Q = tl.load(
|
||||||
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
|
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||||
mask=dim_mask,
|
mask=dim_mask[None, :] & head_mask[:, None],
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_table_offset = seq_idx * block_table_stride
|
block_table_offset = seq_idx * block_table_stride
|
||||||
|
|
||||||
M = tl.full([1], float("-inf"), dtype=tl.float32)
|
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||||
L = tl.full([1], 1.0, dtype=tl.float32)
|
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||||
acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
|
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
|
||||||
|
dtype=tl.float32)
|
||||||
|
|
||||||
# sequence len for this particular sequence
|
# sequence len for this particular sequence
|
||||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||||
|
|
||||||
# alibi slope for this head
|
# alibi slope for this head
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
|
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
|
||||||
|
mask=head_mask,
|
||||||
|
other=0.0)
|
||||||
|
|
||||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||||
|
|
||||||
@ -107,8 +117,8 @@ def kernel_paged_attention_2d(
|
|||||||
|
|
||||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||||
kv_head_idx * stride_v_cache_1 +
|
kv_head_idx * stride_v_cache_1 +
|
||||||
offs_d[:, None] * stride_v_cache_2 +
|
offs_d[None, :] * stride_v_cache_2 +
|
||||||
offs_n[None, :] * stride_v_cache_3)
|
offs_n[:, None] * stride_v_cache_3)
|
||||||
|
|
||||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||||
kv_head_idx * stride_k_cache_1 +
|
kv_head_idx * stride_k_cache_1 +
|
||||||
@ -126,9 +136,9 @@ def kernel_paged_attention_2d(
|
|||||||
else:
|
else:
|
||||||
K = K_load
|
K = K_load
|
||||||
|
|
||||||
# V : (HEAD_SIZE, BLOCK_SIZE)
|
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||||
V_load = tl.load(value_cache_ptr + v_offset,
|
V_load = tl.load(value_cache_ptr + v_offset,
|
||||||
mask=dim_mask[:, None],
|
mask=dim_mask[None, :],
|
||||||
other=0.0)
|
other=0.0)
|
||||||
|
|
||||||
if V_load.dtype.is_fp8():
|
if V_load.dtype.is_fp8():
|
||||||
@ -136,51 +146,59 @@ def kernel_paged_attention_2d(
|
|||||||
else:
|
else:
|
||||||
V = V_load
|
V = V_load
|
||||||
|
|
||||||
tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||||
mask_new = tmp < boundary
|
seq_mask = seq_offset[None, :] < boundary
|
||||||
# S : (BLOCK_SIZE,)
|
|
||||||
S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
|
# S : (num_queries_per_kv, BLOCK_SIZE,)
|
||||||
S += scale * tl.sum(K * Q[:, None], axis=0)
|
S = tl.where(head_mask[:, None] & seq_mask, 0.0,
|
||||||
|
float("-inf")).to(tl.float32)
|
||||||
|
S += scale * tl.dot(Q, K)
|
||||||
|
|
||||||
|
context_len = seq_len - 1
|
||||||
|
|
||||||
if SLIDING_WINDOW > 0:
|
if SLIDING_WINDOW > 0:
|
||||||
S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
|
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
|
||||||
|
-10000)
|
||||||
|
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
S += alibi_slope * (tmp - seq_len + 1)
|
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||||
|
|
||||||
# compute running maximum
|
# compute running maximum
|
||||||
# m_j : (1,)
|
# m_j : (num_queries_per_kv,)
|
||||||
m_j = tl.maximum(M, tl.max(S, axis=0))
|
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||||
|
|
||||||
# P : (BLOCK_SIZE,)
|
# P : (num_queries_per_kv, BLOCK_SIZE,)
|
||||||
P = tl.exp(S - m_j)
|
P = tl.exp(S - m_j[:, None])
|
||||||
|
|
||||||
# l_j : (1,)
|
# l_j : (num_queries_per_kv,)
|
||||||
l_j = tl.sum(P, axis=0)
|
l_j = tl.sum(P, axis=1)
|
||||||
|
|
||||||
# alpha : (1, )
|
# alpha : (num_queries_per_kv, )
|
||||||
alpha = tl.exp(M - m_j)
|
alpha = tl.exp(M - m_j)
|
||||||
|
|
||||||
# acc : (BLOCK_SIZE,)
|
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||||
acc = acc * alpha
|
acc = acc * alpha[:, None]
|
||||||
|
|
||||||
# update constants
|
# update constants
|
||||||
L = L * alpha + l_j
|
L = L * alpha + l_j
|
||||||
M = m_j
|
M = m_j
|
||||||
|
|
||||||
# acc : (BLOCK_SIZE,)
|
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||||
acc += tl.sum(V * P[None, :], axis=1)
|
acc += tl.dot(P.to(V.dtype), V)
|
||||||
|
|
||||||
# epilogue
|
# epilogue
|
||||||
acc = acc / L
|
acc = acc / L[:, None]
|
||||||
|
|
||||||
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
||||||
query_head_idx * output_stride_1)
|
query_head_idx * output_stride_1)
|
||||||
|
|
||||||
tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
|
tl.store(
|
||||||
acc,
|
output_ptr + output_offset[:, None] +
|
||||||
mask=dim_mask)
|
tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||||
|
acc,
|
||||||
|
mask=dim_mask[None, :] & head_mask[:, None],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def chunked_prefill_paged_decode(
|
def chunked_prefill_paged_decode(
|
||||||
@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
|
|||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
num_seqs = len(seq_lens)
|
num_seqs = len(seq_lens)
|
||||||
num_query_heads = query.shape[1]
|
num_query_heads = query.shape[1]
|
||||||
|
num_kv_heads = key.shape[1]
|
||||||
num_queries_per_kv = query.shape[1] // key.shape[1]
|
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||||
head_size = query.shape[2]
|
head_size = query.shape[2]
|
||||||
|
|
||||||
@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
|
|||||||
key_cache = key_cache.view(target_dtype)
|
key_cache = key_cache.view(target_dtype)
|
||||||
value_cache = value_cache.view(target_dtype)
|
value_cache = value_cache.view(target_dtype)
|
||||||
|
|
||||||
|
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
|
||||||
|
16)
|
||||||
|
|
||||||
kernel_paged_attention_2d[(
|
kernel_paged_attention_2d[(
|
||||||
num_seqs,
|
num_seqs,
|
||||||
num_query_heads,
|
num_kv_heads,
|
||||||
)](
|
)](
|
||||||
output_ptr=output,
|
output_ptr=output,
|
||||||
query_ptr=query,
|
query_ptr=query,
|
||||||
@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
|
|||||||
v_scale=v_scale,
|
v_scale=v_scale,
|
||||||
num_query_heads=num_query_heads,
|
num_query_heads=num_query_heads,
|
||||||
num_queries_per_kv=num_queries_per_kv,
|
num_queries_per_kv=num_queries_per_kv,
|
||||||
|
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||||
block_table_stride=block_table.stride(0),
|
block_table_stride=block_table.stride(0),
|
||||||
query_stride_0=query.stride(0),
|
query_stride_0=query.stride(0),
|
||||||
query_stride_1=query.stride(1),
|
query_stride_1=query.stride(1),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user