Optimize MQA Kernel (#452)

Zhuohan Li 2023-07-14 20:06:40 -04:00 committed by GitHub
parent dbed69058c
commit 96853af5a8
5 changed files with 84 additions and 72 deletions

View File

@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,

View File

@@ -74,14 +74,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;

   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
+      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;
@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
-    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
+    head_mapping_ptr, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
-    query_stride);
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);

   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
     query, \
     key_cache, \
     value_cache, \
+    head_mapping, \
     scale, \
     block_tables, \
     context_lens, \
@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
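
Note on the kernel change above: the KV cache is now addressed through runtime strides (kv_block_stride = key_cache.stride(0), kv_head_stride = key_cache.stride(1)) plus a per-query-head kv_head_idx = head_mapping[head_idx], so several query heads can read the same cached KV head without the cache ever being replicated. A minimal sketch (made-up shapes, not part of the commit) showing that this offset arithmetic matches PyTorch's own strides:

    # Check that the kernel's new offset arithmetic matches torch strides for
    # the paged KV cache layout [num_blocks, num_kv_heads, head_size // x, block_size, x].
    import torch

    num_blocks, num_kv_heads, head_size, block_size, x = 16, 2, 64, 8, 8
    key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)

    kv_block_stride = key_cache.stride(0)  # elements to skip per physical block
    kv_head_stride = key_cache.stride(1)   # elements to skip per KV head

    physical_block_number, kv_head_idx = 5, 1
    # Flat element offset the kernel computes for the start of this
    # (block, kv_head) slice, before adding the per-token offsets.
    offset = physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride
    assert key_cache.view(-1)[offset].item() == key_cache[physical_block_number, kv_head_idx, 0, 0, 0].item()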

View File

@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
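
For reference, the resolution order introduced above can be exercised in isolation. The snippet below is only an illustration with SimpleNamespace stand-ins for hf_config, not vLLM code:

    # Mirrors the new get_num_heads logic: MQA models expose one KV head,
    # Falcon-style configs expose n_head_kv, everything else falls back to
    # the tensor-parallel share of num_attention_heads.
    from types import SimpleNamespace

    def resolve_num_kv_heads(hf_config, tensor_parallel_size: int) -> int:
        if getattr(hf_config, "multi_query", False):
            return 1
        if getattr(hf_config, "n_head_kv", None) is not None:
            return hf_config.n_head_kv
        return hf_config.num_attention_heads // tensor_parallel_size

    print(resolve_num_kv_heads(SimpleNamespace(multi_query=True, num_attention_heads=48), 2))  # 1
    print(resolve_num_kv_heads(SimpleNamespace(n_head_kv=8, num_attention_heads=64), 2))       # 8
    print(resolve_num_kv_heads(SimpleNamespace(num_attention_heads=32), 2))                    # 16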

View File

@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """

-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -76,10 +87,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
@@ -107,9 +126,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -118,6 +137,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
@@ -143,11 +163,12 @@ class PagedAttention(nn.Module):
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
@@ -157,8 +178,8 @@ class PagedAttention(nn.Module):
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)

         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
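
The head_mapping buffer built in __init__ is what the decode path hands to the CUDA kernel, while the prompt path still expands K/V with repeat_interleave so xformers sees one KV head per query head. A small sketch (CPU tensors, made-up sizes, not part of the commit) of how the two views line up:

    import torch

    num_heads, num_kv_heads, head_size, num_tokens = 8, 2, 4, 3
    num_queries_per_kv = num_heads // num_kv_heads

    # Decode path: each query head is mapped to the KV head it shares.
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
    print(head_mapping.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]

    # Prompt path: K/V are materialized with one head per query head.
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    print(expanded.shape)  # torch.Size([3, 8, 4])

    # Query head i attends against KV head head_mapping[i]; both views agree.
    assert torch.equal(expanded[:, 5], key[:, head_mapping[5].item()])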

View File

@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple

 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig

 from vllm.model_executor.input_metadata import InputMetadata
@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
+        self.num_kv_heads = 1 if config.multi_query else self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5

         self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
+                                           self.hidden_size + 2 * self.kv_dim,
                                            bias=True,
                                            gather_output=False,
                                            perform_initialization=False)
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
                                         perform_initialization=False)
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,
@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                            dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
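
With multi_query=True the fused c_attn projection emits hidden_size query features but only one head's worth of K and of V, which is why its output is now split with explicit sizes instead of chunked into three equal parts. A quick shape check (toy sizes, no tensor parallelism, not part of the commit):

    import torch

    hidden_size, num_heads = 256, 8
    head_dim = hidden_size // num_heads      # 32
    num_kv_heads = 1                         # multi_query=True
    kv_dim = num_kv_heads * head_dim         # 32

    qkv = torch.randn(5, hidden_size + 2 * kv_dim)       # [num_tokens, 320]
    q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
    print(q.shape, k.shape, v.shape)         # (5, 256) (5, 32) (5, 32)

    # PagedAttention then views these as per-head tensors:
    print(q.view(-1, num_heads, head_dim).shape)          # (5, 8, 32)
    print(k.view(-1, num_kv_heads, head_dim).shape)       # (5, 1, 32)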
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                    with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads

-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                # Else, keep the weights as is for multi-query attention
+
+                loaded_weight = torch.cat([wq, wk, wv], dim=0)

             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
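
The checkpoint's fused c_attn weight is laid out as [hidden_size + 2 * total_kv_size, hidden_size]; instead of tiling K and V up to full multi-head shape as the removed _expand_mqa_mha did, the loader now slices only this rank's query rows and loads the single shared K/V rows as-is. A rough sketch of that sharding (toy sizes, hypothetical 2-way tensor parallelism, not part of the commit):

    import torch

    hidden_size, total_num_heads, tp_size, rank = 64, 8, 2, 1
    head_size = hidden_size // total_num_heads        # 8
    total_num_kv_heads = 1                            # multi_query=True
    total_kv_size = head_size * total_num_kv_heads    # 8

    loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)

    num_heads = total_num_heads // tp_size            # 4 heads per rank
    head_start, head_end = rank * num_heads, (rank + 1) * num_heads

    wq, wk, wv = torch.split(loaded_weight,
                             [hidden_size, total_kv_size, total_kv_size], dim=0)
    wq = wq[head_size * head_start:head_size * head_end]   # this rank's Q rows
    sharded = torch.cat([wq, wk, wv], dim=0)                # shared K/V kept whole
    print(sharded.shape)  # torch.Size([48, 64]) = 32 Q rows + 8 K rows + 8 V rows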