Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)

* Add query stride to multi_query_cached_kv_attention

* Add kernel benchmark script
This commit is contained in:
Woosuk Kwon 2023-04-08 13:36:09 -07:00 committed by GitHub
parent 0f40557af6
commit c267b1a02c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 181 additions and 8 deletions

View File

@ -0,0 +1,165 @@
import functools
import random
import time
from typing import List
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
from cacheflow import attention_ops
def benchmark(name, f, num_warmup = 10, num_iters = 100):
for _ in range(num_warmup):
f()
torch.cuda.synchronize()
start = time.time()
for _ in range(num_iters):
f()
torch.cuda.synchronize()
end = time.time()
print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')
@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
query_lens: List[int],
context_lens: List[int],
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
# Create query tensor.
num_queries = len(query_lens)
cu_query_lens = [0]
for query_len in query_lens:
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]
qkv = torch.randn(
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
query, _, _ = qkv.unbind(dim=1)
# Create key and value cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
# Create block tables.
max_context_len = max(context_lens)
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_queries):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
# Create input and output data structures.
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
# Run our implementation.
def run_ours():
attention_ops.multi_query_cached_kv_attention(
cu_query_lens,
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_len_tensor,
block_size,
max_context_len,
)
benchmark('Ours', run_ours)
# Upper bound: Flash attention.
# Becuase Flash attention cannot read our own cache,
# we make key and value tensors contiguous.
num_kv_tokens = sum(context_lens)
cu_context_lens = [0]
for context_len in context_lens:
cu_context_lens.append(cu_context_lens[-1] + context_len)
cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
ref_output = torch.empty_like(output)
# Run Flash attention.
def run_flash_attn():
_flash_attn_forward(
query,
key,
value,
ref_output,
cu_query_lens,
cu_context_lens,
max(query_lens),
max_context_len,
dropout_p=0.0,
softmax_scale=scale,
causal=True,
return_softmax=False,
)
benchmark('Flash attention', run_flash_attn)
if __name__ == '__main__':
BLOCK_SIZE = 8
NUM_BLOCKS = 1024
DTYPE = torch.half
# LLaMA-13B and OPT-13B
NUM_HEADS = 40
HEAD_SIZE = 128
run_benchmark = functools.partial(
benchmark_multi_query_cached_kv_attention,
num_heads=NUM_HEADS,
head_size=HEAD_SIZE,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
dtype=DTYPE,
)
run_benchmark(
query_lens=[64] * 1,
context_lens=[64] * 1,
)
run_benchmark(
query_lens=[128] * 1,
context_lens=[128] * 1,
)
run_benchmark(
query_lens=[64] * 8,
context_lens=[64] * 8,
)
run_benchmark(
query_lens=[128] * 8,
context_lens=[128] * 8,
)
run_benchmark(
query_lens=[64, 32, 16],
context_lens=[128, 256, 64],
)
run_benchmark(
query_lens=[1024],
context_lens=[1024],
)

View File

@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
const float scale,
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
const int context_len,
const int max_num_blocks_per_seq) {
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_prompts]
const int max_num_blocks_per_seq) {
const int max_num_blocks_per_seq,
const int q_stride) {
const int seq_idx = blockIdx.y;
const int prompt_idx = seq_prompt_mapping[seq_idx];
const int seq_start_idx = cu_query_lens[prompt_idx];
@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
scale,
block_table,
context_len,
max_num_blocks_per_seq);
max_num_blocks_per_seq,
q_stride);
}
} // namespace cacheflow
@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq);
max_num_blocks_per_seq, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();

View File

@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]
query = torch.randn(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
qkv = torch.randn(
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty_like(query)
output = torch.empty(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
attention_ops.multi_query_cached_kv_attention(
cu_query_lens,