mirror of
https://git.datalinker.icu/vllm-project/vllm.git
Optimize MQA Kernel (#452)
commit 96853af5a8
parent dbed69058c

@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -74,14 +74,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 
 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
+      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;
@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
       L_vec logits_vec;
       from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
 
-      const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE;
+      const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride;
 #pragma unroll
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
+    head_mapping_ptr, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
-    query_stride);
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride);
 
 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
 
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
     query, \
     key_cache, \
     value_cache, \
+    head_mapping, \
    scale, \
     block_tables, \
     context_lens, \
@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,        // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping, // [num_heads]
   float scale,
   torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens, // [num_seqs]
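Not part of the commit, but for orientation: a minimal Python sketch of what the new head_mapping / kv_block_stride / kv_head_stride kernel arguments do. The strides are exactly key_cache.stride(0) and key_cache.stride(1) as computed in the launcher above; the sizes below are made up for illustration.

import torch

# Toy sizes; the layout mirrors the k_cache comment in the kernel signature.
num_blocks, num_kv_heads, head_size, block_size, x = 16, 2, 64, 8, 8
num_heads = 8
k_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)

kv_block_stride = k_cache.stride(0)   # what the launcher passes as kv_block_stride
kv_head_stride = k_cache.stride(1)    # what the launcher passes as kv_head_stride
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_heads // num_kv_heads)

# The kernel computes
#   k_ptr = k_cache + physical_block_number * kv_block_stride
#                   + kv_head_idx * kv_head_stride + ...
# which, in tensor terms, is the slice k_cache[physical_block_number, kv_head_idx].
physical_block_number, head_idx = 3, 5
kv_head_idx = int(head_mapping[head_idx])
offset = physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride
assert torch.equal(
    k_cache.flatten()[offset:offset + kv_head_stride],
    k_cache[physical_block_number, kv_head_idx].flatten())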
@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
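Illustration only (the config objects below are stand-ins, not real Hugging Face configs): the resolution order in get_num_heads above is multi_query first, then n_head_kv, then plain MHA divided across tensor-parallel ranks.

from types import SimpleNamespace

def resolve_num_kv_heads(hf_config, tensor_parallel_size: int = 1) -> int:
    # Mirrors ModelConfig.get_num_heads in the hunk above.
    if getattr(hf_config, "multi_query", False):
        return 1                                   # GPTBigCode-style MQA
    if getattr(hf_config, "n_head_kv", None) is not None:
        return hf_config.n_head_kv                 # Falcon-style MQA/GQA
    return hf_config.num_attention_heads // tensor_parallel_size

assert resolve_num_kv_heads(SimpleNamespace(multi_query=True, num_attention_heads=48)) == 1
assert resolve_num_kv_heads(SimpleNamespace(n_head_kv=8, num_attention_heads=128)) == 8
assert resolve_num_kv_heads(SimpleNamespace(num_attention_heads=32), 4) == 8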
@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
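A concrete example of the head_mapping built above, run on CPU for illustration (the module itself places it on CUDA): with 8 query heads and 2 KV heads, each group of 4 consecutive query heads shares one KV head.

import torch

num_heads, num_kv_heads = 8, 2            # example values, not from the commit
num_queries_per_kv = num_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)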
@@ -76,10 +87,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
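The prompt path still runs regular multi-head attention through xformers, so key and value are first expanded back to num_heads. A small shape check of that expansion, with example sizes that are not taken from the commit:

import torch

num_prompt_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_prompt_tokens, num_kv_heads, head_size)
value = torch.randn(num_prompt_tokens, num_kv_heads, head_size)

key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

print(key.shape, value.shape)  # torch.Size([5, 8, 64]) torch.Size([5, 8, 64])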
@@ -107,9 +126,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -118,6 +137,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
@@ -143,11 +163,12 @@ class PagedAttention(nn.Module):
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
 
@@ -157,8 +178,8 @@ class PagedAttention(nn.Module):
 
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
+        self.num_kv_heads = 1 if config.multi_query else self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
 
         self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
+                                           self.hidden_size + 2 * self.kv_dim,
                                            bias=True,
                                            gather_output=False,
                                            perform_initialization=False)
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
                                         perform_initialization=False)
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                            dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
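With multi_query=True, c_attn now produces hidden_size + 2 * kv_dim features instead of 3 * hidden_size, so the split sizes are uneven. A shape sketch with assumed GPTBigCode-like dimensions, not taken from the commit:

import torch

hidden_size, num_heads = 6144, 48         # example sizes
head_dim = hidden_size // num_heads       # 128
num_kv_heads = 1                          # multi_query attention
kv_dim = num_kv_heads * head_dim

qkv = torch.randn(2, hidden_size + 2 * kv_dim)   # fused projection for 2 tokens
q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 6144]) torch.Size([2, 128]) torch.Size([2, 128])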
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
 
-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                # Else, keep the weights as is for multi-query attention
+
+                loaded_weight = torch.cat([wq, wk, wv], dim=0)
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
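To make the new sharding concrete: for an MQA checkpoint the fused c_attn weight is [hidden_size + 2 * head_size, hidden_size]; each tensor-parallel rank slices out its own query heads from wq, while the single-KV-head wk and wv are kept whole (replicated) on every rank. A sketch with toy sizes, not a real model:

import torch

hidden_size, total_num_heads = 256, 8           # toy sizes
head_size = hidden_size // total_num_heads      # 32
total_num_kv_heads = 1                          # multi_query=True
total_kv_size = head_size * total_num_kv_heads

tp_world_size, tp_rank = 2, 0
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
wq, wk, wv = torch.split(loaded_weight,
                         [hidden_size, total_kv_size, total_kv_size], dim=0)
wq = wq[head_size * head_start:head_size * head_end]
# wk / wv stay whole: with one KV head they are replicated across ranks.
sharded = torch.cat([wq, wk, wv], dim=0)
print(sharded.shape)  # torch.Size([192, 256]) -> 128 (q shard) + 32 + 32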