Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-09 06:25:01 +08:00
Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)
* Add query stride to multi_query_cached_kv_attention
* Add kernel benchmark script
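Background on the stride change: the benchmark and test below slice the query out of a packed qkv tensor with qkv.unbind(dim=1), which leaves the query non-contiguous; its per-token stride is that of the packed tensor, not num_heads * head_size. A minimal sketch of that PyTorch behavior, with illustrative sizes only (not part of this commit):

import torch

num_tokens, num_heads, head_size = 4, 40, 128   # illustrative sizes
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# unbind returns views into the packed qkv storage, so the per-token stride
# is 3 * num_heads * head_size rather than num_heads * head_size.
query, key, value = qkv.unbind(dim=1)
print(query.is_contiguous())            # False
print(query.stride(0))                  # 15360 = 3 * 40 * 128
print(query.contiguous().stride(0))     # 5120  = 40 * 128

This per-token stride is what the new q_stride argument carries down into the kernel.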
This commit is contained in:
parent 0f40557af6
commit c267b1a02c
benchmark/benchmark_attention.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import functools
import random
import time
from typing import List

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch

from cacheflow import attention_ops


def benchmark(name, f, num_warmup = 10, num_iters = 100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
    query_lens: List[int],
    context_lens: List[int],
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
    # Create query tensor.
    num_queries = len(query_lens)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    query, _, _ = qkv.unbind(dim=1)

    # Create key and value cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    # Create block tables.
    max_context_len = max(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    # Create input and output data structures.
    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    # Run our implementation.
    def run_ours():
        attention_ops.multi_query_cached_kv_attention(
            cu_query_lens,
            output,
            query,
            key_cache,
            value_cache,
            scale,
            block_tables,
            context_len_tensor,
            block_size,
            max_context_len,
        )
    benchmark('Ours', run_ours)

    # Upper bound: Flash attention.
    # Because Flash attention cannot read our own cache,
    # we make key and value tensors contiguous.
    num_kv_tokens = sum(context_lens)
    cu_context_lens = [0]
    for context_len in context_lens:
        cu_context_lens.append(cu_context_lens[-1] + context_len)
    cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)
    ref_output = torch.empty_like(output)

    # Run Flash attention.
    def run_flash_attn():
        _flash_attn_forward(
            query,
            key,
            value,
            ref_output,
            cu_query_lens,
            cu_context_lens,
            max(query_lens),
            max_context_len,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=True,
            return_softmax=False,
        )
    benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        benchmark_multi_query_cached_kv_attention,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )

    run_benchmark(
        query_lens=[64] * 1,
        context_lens=[64] * 1,
    )
    run_benchmark(
        query_lens=[128] * 1,
        context_lens=[128] * 1,
    )
    run_benchmark(
        query_lens=[64] * 8,
        context_lens=[64] * 8,
    )
    run_benchmark(
        query_lens=[128] * 8,
        context_lens=[128] * 8,
    )
    run_benchmark(
        query_lens=[64, 32, 16],
        context_lens=[128, 256, 64],
    )
    run_benchmark(
        query_lens=[1024],
        context_lens=[1024],
    )
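A note on the cache layout the benchmark fills: each key block is stored as [num_heads, head_size // x, block_size, x] with x = 16 // element_size, so a single token's key vector is spread over the third and fifth dimensions. A minimal sketch of how one vector maps back out of that layout, assuming the shapes above (gather_key is an illustrative helper, not part of the repository):

import torch

def gather_key(key_cache: torch.Tensor, block: int, head: int, offset: int) -> torch.Tensor:
    # key_cache: [num_blocks, num_heads, head_size // x, block_size, x]
    # Returns the head_size-dim key vector of the token at `offset` within `block`,
    # assuming element d of the vector is stored at index (d // x, offset, d % x).
    slab = key_cache[block, head, :, offset, :]   # [head_size // x, x]
    return slab.reshape(-1)                       # flattens in (d // x, d % x) order

For example, gather_key(key_cache, block=0, head=0, offset=0) yields a [head_size] vector under this layout assumption.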
@@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
   const float scale,
   const int* __restrict__ block_table,    // [num_seqs, max_num_blocks_per_seq]
   const int context_len,
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
   // For example, if the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
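To see what the new pointer arithmetic fixes, here is a CPU-side sketch (illustrative only, not repository code) that walks the packed qkv storage the way the kernel walks q: stepping by the old contiguous row size lands in key/value data, while stepping by query.stride(0) recovers the intended query row.

import torch

num_tokens, num_heads, head_size = 4, 2, 8      # small illustrative sizes
qkv = torch.randn(num_tokens, 3, num_heads, head_size)
query, _, _ = qkv.unbind(dim=1)

flat = qkv.reshape(-1)                          # the storage a raw pointer would walk
row = num_heads * head_size                     # row size the old kernel assumed
stride = query.stride(0)                        # actual per-token stride (3 * row)

token = 2
old_read = flat[token * row : token * row + row]
new_read = flat[token * stride : token * stride + row]
print(torch.equal(new_read, query[token].reshape(-1)))   # True
print(torch.equal(old_read, query[token].reshape(-1)))   # False: reads K/V data instead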
@@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
   const float scale,
   const int* __restrict__ block_tables,   // [num_prompts, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_prompts]
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
   const int seq_idx = blockIdx.y;
   const int prompt_idx = seq_prompt_mapping[seq_idx];
   const int seq_start_idx = cu_query_lens[prompt_idx];
@@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
     scale,
     block_table,
     context_len,
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq,
+    q_stride);
 }

 } // namespace cacheflow
@@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq, \
+    query_stride);


 // TODO(woosuk): Tune NUM_THREADS.
@@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

   int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
   int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
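One detail worth noting about the launcher change (a general PyTorch fact, not something introduced by this diff): Tensor strides are counted in elements, not bytes, so query.stride(0) can be passed straight to the kernel and applied to a typed pointer in q + seq_idx * q_stride. In Python terms:

import torch

qkv = torch.randn(4, 3, 40, 128, dtype=torch.half)
query, _, _ = qkv.unbind(dim=1)
print(query.stride(0))                            # 15360 elements per token
print(query.stride(0) * query.element_size())     # 30720 bytes per token (fp16)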
@@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
         cu_query_lens.append(cu_query_lens[-1] + query_len)
     num_total_tokens = cu_query_lens[-1]

-    query = torch.randn(
-        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

     attention_ops.multi_query_cached_kv_attention(
         cu_query_lens,