Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:34:57 +08:00)
[OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel (#420)
commit 79af7e96a0
parent 621980bdc0
@@ -86,6 +86,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int kv_block_stride,
   const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
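The two added lines make an existing assumption explicit: the block's threads are partitioned into groups of THREAD_GROUP_SIZE, and NUM_THREAD_GROUPS only counts them correctly when the group size divides NUM_THREADS, which the new assert enforces. A minimal sketch of the arithmetic, using NUM_THREADS = 128, WARP_SIZE = 32, and BLOCK_SIZE = 16 as assumed example values (in the kernel these are template parameters):

#include <cassert>
#include <cstdio>

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int main() {
  // Example values only; the kernel receives these as template parameters.
  constexpr int WARP_SIZE = 32;
  constexpr int NUM_THREADS = 128;
  constexpr int BLOCK_SIZE = 16; // tokens per KV-cache block (assumed)

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);  // -> 2
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // -> 64
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "THREAD_GROUP_SIZE must divide NUM_THREADS");

  printf("group size = %d, num groups = %d\n",
         THREAD_GROUP_SIZE, NUM_THREAD_GROUPS);
  return 0;
}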
@@ -120,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
   // th vectors of the query, and so on.
   // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  Q_vec q_vecs[NUM_VECS_PER_THREAD];
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
+  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
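This hunk is the core of the optimization. Before, q_vecs lived in registers and every thread group loaded the full query head from global memory on its own; now the query is staged once in a shared-memory tile, the load loop is strided by NUM_THREAD_GROUPS so the groups split the work, and __syncthreads() publishes the tile before any group reads it. A stripped-down sketch of that staging pattern, with hypothetical fixed sizes (float4 vectors, 128 threads per block) in place of the kernel's template machinery:

// Example sizes only; the real kernel derives these from template parameters.
constexpr int THREAD_GROUP_SIZE = 2;    // threads cooperating on one token
constexpr int NUM_THREAD_GROUPS = 64;   // 128 threads / group size 2
constexpr int NUM_VECS_PER_THREAD = 16; // float4 vectors per thread

// Launch with blockDim.x == 128; q points to one query head holding
// THREAD_GROUP_SIZE * NUM_VECS_PER_THREAD float4 elements.
__global__ void stage_query_sketch(const float4* __restrict__ q) {
  // One copy of the query head, shared by every thread group in the block.
  __shared__ float4 q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

  const int thread_group_idx = threadIdx.x / THREAD_GROUP_SIZE;
  const int thread_group_offset = threadIdx.x % THREAD_GROUP_SIZE;

  // Each group loads only a strided slice of the vector indices, so the
  // block reads the query from global memory once instead of once per group.
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = q[vec_idx];
  }
  __syncthreads(); // make the staged tile visible to all groups

  // ... every group now reads q_vecs[thread_group_offset][i] from shared memory ...
}

The TODO kept in the diff suggests the full barrier is conservative and might later be replaced by a narrower fence placed right before q_vecs is consumed.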
@@ -173,7 +176,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
     // Compute dot product.
     // This includes a reduction across the threads in the same thread group.
-    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
     // Add the ALiBi bias if slopes are given.
     qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
 
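The only change in this last hunk is that Qk_dot now reads the query from the shared q_vecs tile, indexed by thread_group_offset, instead of from per-thread registers. Qk_dot itself sums partial dot products across the threads of one group; a minimal sketch of that style of reduction, assuming the group is a power-of-two set of threads laid out contiguously within a warp (a hypothetical group_dot, not vLLM's actual Qk_dot):

// Hypothetical device helper illustrating a thread-group dot-product
// reduction with warp shuffles; THREAD_GROUP_SIZE must be a power of two.
template <int THREAD_GROUP_SIZE>
__device__ float group_dot(const float4* q, const float4* k, int num_vecs) {
  // Per-thread partial sum over this thread's slice of the vectors.
  float sum = 0.0f;
  for (int i = 0; i < num_vecs; ++i) {
    sum += q[i].x * k[i].x + q[i].y * k[i].y +
           q[i].z * k[i].z + q[i].w * k[i].w;
  }
  // Butterfly reduction across the group's threads.
  #pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(0xffffffff, sum, mask);
  }
  return sum; // every thread in the group holds the full dot product
}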