mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:35:01 +08:00
Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)
* Add query stride to multi_query_cached_kv_attention * Add kernel benchmark script
This commit is contained in:
parent
0f40557af6
commit
c267b1a02c
165
benchmark/benchmark_attention.py
Normal file
165
benchmark/benchmark_attention.py
Normal file
@ -0,0 +1,165 @@
|
||||
import functools
|
||||
import random
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
|
||||
from cacheflow import attention_ops
|
||||
|
||||
|
||||
def benchmark(name, f, num_warmup = 10, num_iters = 100):
|
||||
for _ in range(num_warmup):
|
||||
f()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
f()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def benchmark_multi_query_cached_kv_attention(
|
||||
query_lens: List[int],
|
||||
context_lens: List[int],
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
|
||||
f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
|
||||
f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
|
||||
# Create query tensor.
|
||||
num_queries = len(query_lens)
|
||||
cu_query_lens = [0]
|
||||
for query_len in query_lens:
|
||||
cu_query_lens.append(cu_query_lens[-1] + query_len)
|
||||
num_total_tokens = cu_query_lens[-1]
|
||||
qkv = torch.randn(
|
||||
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
|
||||
# Create key and value cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
# Create block tables.
|
||||
max_context_len = max(context_lens)
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_queries):
|
||||
block_table = [
|
||||
random.randint(0, num_blocks - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
# Create input and output data structures.
|
||||
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
|
||||
context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
|
||||
# Run our implementation.
|
||||
def run_ours():
|
||||
attention_ops.multi_query_cached_kv_attention(
|
||||
cu_query_lens,
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
scale,
|
||||
block_tables,
|
||||
context_len_tensor,
|
||||
block_size,
|
||||
max_context_len,
|
||||
)
|
||||
benchmark('Ours', run_ours)
|
||||
|
||||
# Upper bound: Flash attention.
|
||||
# Becuase Flash attention cannot read our own cache,
|
||||
# we make key and value tensors contiguous.
|
||||
num_kv_tokens = sum(context_lens)
|
||||
cu_context_lens = [0]
|
||||
for context_len in context_lens:
|
||||
cu_context_lens.append(cu_context_lens[-1] + context_len)
|
||||
cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
|
||||
qkv = torch.randn(
|
||||
num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
ref_output = torch.empty_like(output)
|
||||
|
||||
# Run Flash attention.
|
||||
def run_flash_attn():
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
ref_output,
|
||||
cu_query_lens,
|
||||
cu_context_lens,
|
||||
max(query_lens),
|
||||
max_context_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
return_softmax=False,
|
||||
)
|
||||
benchmark('Flash attention', run_flash_attn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
BLOCK_SIZE = 8
|
||||
NUM_BLOCKS = 1024
|
||||
DTYPE = torch.half
|
||||
|
||||
# LLaMA-13B and OPT-13B
|
||||
NUM_HEADS = 40
|
||||
HEAD_SIZE = 128
|
||||
|
||||
run_benchmark = functools.partial(
|
||||
benchmark_multi_query_cached_kv_attention,
|
||||
num_heads=NUM_HEADS,
|
||||
head_size=HEAD_SIZE,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_blocks=NUM_BLOCKS,
|
||||
dtype=DTYPE,
|
||||
)
|
||||
|
||||
run_benchmark(
|
||||
query_lens=[64] * 1,
|
||||
context_lens=[64] * 1,
|
||||
)
|
||||
run_benchmark(
|
||||
query_lens=[128] * 1,
|
||||
context_lens=[128] * 1,
|
||||
)
|
||||
run_benchmark(
|
||||
query_lens=[64] * 8,
|
||||
context_lens=[64] * 8,
|
||||
)
|
||||
run_benchmark(
|
||||
query_lens=[128] * 8,
|
||||
context_lens=[128] * 8,
|
||||
)
|
||||
run_benchmark(
|
||||
query_lens=[64, 32, 16],
|
||||
context_lens=[128, 256, 64],
|
||||
)
|
||||
run_benchmark(
|
||||
query_lens=[1024],
|
||||
context_lens=[1024],
|
||||
)
|
||||
@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
|
||||
const float scale,
|
||||
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int context_len,
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_prompts]
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int prompt_idx = seq_prompt_mapping[seq_idx];
|
||||
const int seq_start_idx = cu_query_lens[prompt_idx];
|
||||
@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
|
||||
scale,
|
||||
block_table,
|
||||
context_len,
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq,
|
||||
q_stride);
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq, \
|
||||
query_stride);
|
||||
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
|
||||
int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
|
||||
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
|
||||
|
||||
@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
|
||||
cu_query_lens.append(cu_query_lens[-1] + query_len)
|
||||
num_total_tokens = cu_query_lens[-1]
|
||||
|
||||
query = torch.randn(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(
|
||||
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
|
||||
attention_ops.multi_query_cached_kv_attention(
|
||||
cu_query_lens,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user