From 79af7e96a0e2fc9f340d1939192122c3ae38ff17 Mon Sep 17 00:00:00 2001
From: Dean Leitersdorf
Date: Fri, 4 Aug 2023 20:57:29 +0300
Subject: [PATCH] [OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel (#420)

---
 csrc/attention/attention_kernels.cu | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 930220410738..568d1fb1ad24 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -86,6 +86,8 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int kv_block_stride,
   const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -120,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
   // th vectors of the query, and so on.
   // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  Q_vec q_vecs[NUM_VECS_PER_THREAD];
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
+  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -173,7 +176,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
 
       // Add the ALiBi bias if slopes are given.
       qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
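
Note (not part of the patch): the sketch below is a minimal, standalone illustration of the load pattern this diff introduces, in which a block stages the query vectors in shared memory once, with thread groups striding over the vector index, and synchronizes before any group consumes them. All names and sizes in the sketch (cooperative_q_load, NUM_THREADS, THREAD_GROUP_SIZE, HEAD_SIZE, VEC_SIZE) are simplified assumptions for illustration, not vLLM's actual kernel configuration or API.

// Standalone sketch (assumed constants, not the vLLM kernel) of the cooperative
// shared-memory staging of the query that the patch introduces.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int NUM_THREADS = 128;       // threads per block (assumed)
constexpr int THREAD_GROUP_SIZE = 2;   // threads sharing one dot product (assumed)
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
constexpr int HEAD_SIZE = 128;         // query head size (assumed)
constexpr int VEC_SIZE = 4;            // elements per vectorized load (assumed)
constexpr int NUM_VECS = HEAD_SIZE / VEC_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_VECS / THREAD_GROUP_SIZE;

__global__ void cooperative_q_load(const float4* __restrict__ q, float* __restrict__ out) {
  // One slot per lane of a thread group; every group reuses the same slots later.
  __shared__ float4 q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

  const int thread_idx = threadIdx.x;
  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Cooperative load: thread groups stride over the vector index so each query
  // vector is fetched from global memory once per block, not once per group.
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = q[vec_idx];
  }
  __syncthreads();  // all groups must see the staged query before using it

  // Toy use of the staged data: each thread sums its lane's share of q.
  float acc = 0.f;
  for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
    const float4 v = q_vecs[thread_group_offset][i];
    acc += v.x + v.y + v.z + v.w;
  }
  out[thread_idx] = acc;
}

int main() {
  float4* d_q;
  float* d_out;
  cudaMalloc(&d_q, NUM_VECS * sizeof(float4));
  cudaMalloc(&d_out, NUM_THREADS * sizeof(float));
  cudaMemset(d_q, 0, NUM_VECS * sizeof(float4));

  cooperative_q_load<<<1, NUM_THREADS>>>(d_q, d_out);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_q);
  cudaFree(d_out);
  return 0;
}

As I read the diff, the intent is that every thread group in a block works on the same query, so staging it once in shared memory replaces each group's redundant global loads into registers; the __syncthreads() (or, per the patch's TODO, a narrower barrier placed right before q_vecs is used) makes the staged data visible to all groups before the dot products.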