[BugFix] Implement RoPE for GPT-J (#941)

Woosuk Kwon 2023-09-06 11:54:33 +09:00 committed by GitHub
parent c9927c1a6a
commit 320a622ec4
5 changed files with 122 additions and 72 deletions

View File

@@ -1,15 +1,16 @@
#include <torch/extension.h>
-void rotary_embedding_neox(
+void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
-torch::Tensor& cos_sin_cache);
+torch::Tensor& cos_sin_cache,
+bool is_neox);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rotary_embedding_neox",
&rotary_embedding_neox,
"Apply GPT-NeoX style rotary embedding to query and key");
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
}

View File

@@ -5,8 +5,38 @@
namespace vllm {
-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+scalar_t* __restrict__ arr,
+const scalar_t* __restrict__ cos_ptr,
+const scalar_t* __restrict__ sin_ptr,
+int rot_offset,
+int embed_dim)
+{
+int x_index, y_index;
+scalar_t cos, sin;
+if (IS_NEOX) {
+// GPT-NeoX style rotary embedding.
+x_index = rot_offset;
+y_index = embed_dim + rot_offset;
+cos = __ldg(cos_ptr + x_index);
+sin = __ldg(sin_ptr + x_index);
+} else {
+// GPT-J style rotary embedding.
+x_index = 2 * rot_offset;
+y_index = 2 * rot_offset + 1;
+cos = __ldg(cos_ptr + x_index / 2);
+sin = __ldg(sin_ptr + x_index / 2);
+}
+const scalar_t x = arr[x_index];
+const scalar_t y = arr[y_index];
+arr[x_index] = x * cos - y * sin;
+arr[y_index] = y * cos + x * sin;
+}
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
@@ -23,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
+const scalar_t* cos_ptr = cache_ptr;
+const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
-const int x_index = rot_offset;
-const int y_index = embed_dim + rot_offset;
-const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
-const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
-const scalar_t cos = __ldg(cache_ptr + x_index);
-const scalar_t sin = __ldg(cache_ptr + y_index);
-const scalar_t q_x = query[token_head + x_index];
-const scalar_t q_y = query[token_head + y_index];
-query[out_x] = q_x * cos - q_y * sin;
-query[out_y] = q_y * cos + q_x * sin;
+apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
-const int x_index = rot_offset;
-const int y_index = embed_dim + rot_offset;
-const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
-const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
-const scalar_t cos = __ldg(cache_ptr + x_index);
-const scalar_t sin = __ldg(cache_ptr + y_index);
-const scalar_t k_x = key[token_head + x_index];
-const scalar_t k_y = key[token_head + y_index];
-key[out_x] = k_x * cos - k_y * sin;
-key[out_y] = k_y * cos + k_x * sin;
+apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+sin_ptr, rot_offset, embed_dim);
}
}
} // namespace vllm
-void rotary_embedding_neox(
+void rotary_embedding(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
int head_size,
-torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
-{
+torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
+bool is_neox) {
int num_tokens = query.size(0);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
@@ -87,18 +96,32 @@ void rotary_embedding_neox(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding_neox",
"rotary_embedding",
[&] {
-vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-positions.data_ptr<int64_t>(),
-query.data_ptr<scalar_t>(),
-key.data_ptr<scalar_t>(),
-cos_sin_cache.data_ptr<scalar_t>(),
-rot_dim,
-query_stride,
-key_stride,
-num_heads,
-num_kv_heads,
-head_size);
+if (is_neox) {
+vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+positions.data_ptr<int64_t>(),
+query.data_ptr<scalar_t>(),
+key.data_ptr<scalar_t>(),
+cos_sin_cache.data_ptr<scalar_t>(),
+rot_dim,
+query_stride,
+key_stride,
+num_heads,
+num_kv_heads,
+head_size);
+} else {
+vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+positions.data_ptr<int64_t>(),
+query.data_ptr<scalar_t>(),
+key.data_ptr<scalar_t>(),
+cos_sin_cache.data_ptr<scalar_t>(),
+rot_dim,
+query_stride,
+key_stride,
+num_heads,
+num_kv_heads,
+head_size);
+}
});
}
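For reference, here is a minimal PyTorch sketch (not part of the commit; the function name is illustrative) of what the two IS_NEOX branches of apply_rotary_embedding compute for the rotated slice of one head. It assumes the cache layout the kernel reads above: each cached row of length rot_dim stores the cos values in its first half and the sin values in its second half.

import torch

def rotate_slice(x: torch.Tensor, cache_row: torch.Tensor, is_neox: bool) -> torch.Tensor:
    # x: the first rot_dim elements of one head; cache_row: [rot_dim] = [cos | sin].
    embed_dim = cache_row.numel() // 2
    cos, sin = cache_row[:embed_dim], cache_row[embed_dim:]
    out = x.clone()
    for i in range(embed_dim):
        if is_neox:
            # GPT-NeoX style: pair element i with element embed_dim + i.
            xi, yi = i, embed_dim + i
        else:
            # GPT-J style: pair adjacent even/odd elements (2*i, 2*i + 1).
            xi, yi = 2 * i, 2 * i + 1
        out[xi] = x[xi] * cos[i] - x[yi] * sin[i]
        out[yi] = x[yi] * cos[i] + x[xi] * sin[i]
    return out

The interleaved pairing is what the rotate_gptj helper in the test file below checks against; the half-split pairing corresponds to rotate_neox.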

View File

@@ -7,49 +7,64 @@ import torch.nn.functional as F
from vllm import pos_encoding_ops
+IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
-NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
+NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
SEEDS = [0]
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
-def apply_rotary_pos_emb(
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+x1 = x[..., ::2]
+x2 = x[..., 1::2]
+x = torch.stack((-x2, x1), dim=-1)
+return x.flatten(-2)
+def apply_rope(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
+is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
-q_embed = (q * cos) + (rotate_half(q) * sin)
-k_embed = (k * cos) + (rotate_half(k) * sin)
+rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+q_embed = (q * cos) + (rotate_fn(q) * sin)
+k_embed = (k * cos) + (rotate_fn(k) * sin)
return q_embed, k_embed
-class RefRotaryEmbeddingNeox(nn.Module):
-"""Reference implementation of the GPT-NeoX style rotary embedding."""
+class RefRotaryEmbedding(nn.Module):
+"""Reference implementation of rotary embedding."""
def __init__(
self,
dim: int,
-max_position_embeddings: int = 2048,
+is_neox_style: bool,
+max_position_embeddings: int = 8192,
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
+self.is_neox_style = is_neox_style
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-emb = torch.cat((freqs, freqs), dim=-1)
+if is_neox_style:
+emb = torch.cat((freqs, freqs), dim=-1)
+else:
+emb = torch.repeat_interleave(freqs, 2, -1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
@@ -61,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
@@ -71,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
-query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+self.is_neox_style)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
@@ -82,6 +98,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
return query, key
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -89,7 +106,8 @@ class RefRotaryEmbeddingNeox(nn.Module):
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
-def test_rotary_embedding_neox(
+def test_rotary_embedding(
+is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
@@ -104,15 +122,15 @@ def test_rotary_embedding_neox(
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
-positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
query = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
-device='cuda')
+device="cuda")
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
-device='cuda')
+device="cuda")
# Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -126,20 +144,22 @@ def test_rotary_embedding_neox(
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
-pos_encoding_ops.rotary_embedding_neox(
+pos_encoding_ops.rotary_embedding(
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
+is_neox_style,
)
# Run the reference implementation.
-ref_rotary_embedding = RefRotaryEmbeddingNeox(
+ref_rotary_embedding = RefRotaryEmbedding(
dim=rotary_dim,
+is_neox_style=is_neox_style,
max_position_embeddings=max_position,
base=base,
-).to(dtype=dtype, device='cuda')
+).to(dtype=dtype, device="cuda")
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),

View File

@@ -242,7 +242,7 @@ class PagedAttention(nn.Module):
class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with GPT-NeoX style rotary embedding."""
"""PagedAttention with rotary embedding."""
def __init__(
self,
@@ -253,8 +253,10 @@ class PagedAttentionWithRoPE(PagedAttention):
max_position: int = 8192,
base: int = 10000,
num_kv_heads: Optional[int] = None,
+is_neox_style: bool = True,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
+self.is_neox_style = is_neox_style
# Create the cos and sin cache.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
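The hunk above stops right where the cos/sin cache is created. As a hedged sketch (the file's actual construction code is truncated here), a [max_position, rot_dim] cache matching the layout the CUDA kernel reads, cos values in the first rot_dim/2 columns and sin values in the last rot_dim/2, could be built along these lines:

import torch

# Illustrative values; the real layer takes these from its constructor arguments.
rotary_dim, max_position, base = 64, 8192, 10000
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_position, rotary_dim // 2]
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]

The same cache serves both styles: the kernel always reads cos[i] and sin[i] for pair i, and only the pairing of head elements differs between NeoX and GPT-J.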
@@ -303,12 +305,13 @@ class PagedAttentionWithRoPE(PagedAttention):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
-pos_encoding_ops.rotary_embedding_neox(
+pos_encoding_ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
+self.is_neox_style,
)
return super().forward(
query,

View File

@@ -67,8 +67,11 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
-self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
-scaling, config.rotary_dim)
+self.attn = PagedAttentionWithRoPE(self.num_heads,
+self.head_size,
+scaling,
+config.rotary_dim,
+is_neox_style=False)
self.warmup = False
def forward(
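With is_neox_style=False passed here, GPT-J's attention applies the interleaved (GPT-J style) rotation in the fused kernel. A quick smoke test of the end result might look like the following sketch; the model name and sampling settings are illustrative and not part of this change:

from vllm import LLM, SamplingParams

# Assumes a CUDA GPU with enough memory for the GPT-J checkpoint.
llm = LLM(model="EleutherAI/gpt-j-6b")
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)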