mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 19:07:56 +08:00
Add batched RoPE kernel (#3095)
This commit is contained in:
parent
ae0ccb4017
commit
7e9bd08f60
120
benchmarks/kernels/benchmark_rope.py
Normal file
120
benchmarks/kernels/benchmark_rope.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import nvtx
|
||||||
|
from itertools import accumulate
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_rope_kernels_multi_lora(
|
||||||
|
is_neox_style: bool,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
device: str,
|
||||||
|
max_position: int = 8192,
|
||||||
|
base: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
if rotary_dim is None:
|
||||||
|
rotary_dim = head_size
|
||||||
|
# silulating serving 4 LoRAs
|
||||||
|
scaling_factors = [1, 2, 4, 8]
|
||||||
|
# batched RoPE can take multiple scaling factors
|
||||||
|
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||||
|
is_neox_style, {
|
||||||
|
"type": "linear",
|
||||||
|
"factor": tuple(scaling_factors)
|
||||||
|
})
|
||||||
|
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||||
|
# instances to simulate the same behavior
|
||||||
|
non_batched_ropes = []
|
||||||
|
for scaling_factor in scaling_factors:
|
||||||
|
non_batched_ropes.append(
|
||||||
|
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
|
{
|
||||||
|
"type": "linear",
|
||||||
|
"factor": (scaling_factor, )
|
||||||
|
}))
|
||||||
|
|
||||||
|
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||||
|
query = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_heads * head_size,
|
||||||
|
dtype=dtype)
|
||||||
|
key = torch.randn_like(query)
|
||||||
|
|
||||||
|
# create query offsets for batched RoPE, we concat multiple kv cache
|
||||||
|
# together and each query needs to find the right kv cache of its type
|
||||||
|
offset_map = torch.tensor(
|
||||||
|
list(
|
||||||
|
accumulate([0] + [
|
||||||
|
max_position * scaling_factor * 2
|
||||||
|
for scaling_factor in scaling_factors[:-1]
|
||||||
|
])))
|
||||||
|
query_types = torch.randint(0,
|
||||||
|
len(scaling_factors), (batch_size, seq_len),
|
||||||
|
device=device)
|
||||||
|
# map query types to offsets
|
||||||
|
query_offsets = offset_map[query_types]
|
||||||
|
# the kernel takes flattened offsets
|
||||||
|
flatten_offsets = query_offsets.flatten()
|
||||||
|
|
||||||
|
# batched queries of the same type together for non-batched RoPE
|
||||||
|
queries = [query[query_types == i] for i in range(len(scaling_factors))]
|
||||||
|
keys = [key[query_types == i] for i in range(len(scaling_factors))]
|
||||||
|
packed_qkr = zip(queries, keys, non_batched_ropes)
|
||||||
|
# synchronize before start timing
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with nvtx.annotate("non-batched", color="yellow"):
|
||||||
|
for q, k, r in packed_qkr:
|
||||||
|
r.forward(positions, q, k)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with nvtx.annotate("batched", color="green"):
|
||||||
|
batched_rope.forward(positions, query, key, flatten_offsets)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Benchmark the rotary embedding kernels.")
|
||||||
|
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=16)
|
||||||
|
parser.add_argument("--seq-len", type=int, default=512)
|
||||||
|
parser.add_argument("--num-heads", type=int, default=8)
|
||||||
|
parser.add_argument("--head-size",
|
||||||
|
type=int,
|
||||||
|
choices=[64, 80, 96, 112, 128, 256],
|
||||||
|
default=128)
|
||||||
|
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||||
|
parser.add_argument("--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["bfloat16", "float"],
|
||||||
|
default="float")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--device",
|
||||||
|
type=str,
|
||||||
|
choices=["cuda:0", "cuda:1"],
|
||||||
|
default="cuda:0")
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
benchmark_rope_kernels_multi_lora(
|
||||||
|
is_neox_style=args.is_neox_style,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
seq_len=args.seq_len,
|
||||||
|
num_heads=args.num_heads,
|
||||||
|
head_size=args.head_size,
|
||||||
|
rotary_dim=args.rotary_dim,
|
||||||
|
dtype=getattr(torch, args.dtype),
|
||||||
|
seed=args.seed,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
10
csrc/ops.h
10
csrc/ops.h
@ -53,6 +53,16 @@ void rotary_embedding(
|
|||||||
torch::Tensor& cos_sin_cache,
|
torch::Tensor& cos_sin_cache,
|
||||||
bool is_neox);
|
bool is_neox);
|
||||||
|
|
||||||
|
void batched_rotary_embedding(
|
||||||
|
torch::Tensor& positions,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key,
|
||||||
|
int head_size,
|
||||||
|
torch::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox,
|
||||||
|
int rot_dim,
|
||||||
|
torch::Tensor& cos_sin_cache_offsets);
|
||||||
|
|
||||||
void silu_and_mul(
|
void silu_and_mul(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
torch::Tensor& input);
|
torch::Tensor& input);
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t, bool IS_NEOX>
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
inline __device__ void apply_rotary_embedding(
|
inline __device__ void apply_token_rotary_embedding(
|
||||||
scalar_t* __restrict__ arr,
|
scalar_t* __restrict__ arr,
|
||||||
const scalar_t* __restrict__ cos_ptr,
|
const scalar_t* __restrict__ cos_ptr,
|
||||||
const scalar_t* __restrict__ sin_ptr,
|
const scalar_t* __restrict__ sin_ptr,
|
||||||
@ -37,6 +37,42 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
arr[y_index] = y * cos + x * sin;
|
arr[y_index] = y * cos + x * sin;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
|
inline __device__ void apply_rotary_embedding(
|
||||||
|
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||||
|
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||||
|
const scalar_t* cache_ptr,
|
||||||
|
const int head_size,
|
||||||
|
const int num_heads,
|
||||||
|
const int num_kv_heads,
|
||||||
|
const int rot_dim,
|
||||||
|
const int token_idx,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride)
|
||||||
|
{
|
||||||
|
const int embed_dim = rot_dim / 2;
|
||||||
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
|
|
||||||
|
const int nq = num_heads * embed_dim;
|
||||||
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
|
const int head_idx = i / embed_dim;
|
||||||
|
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
|
const int rot_offset = i % embed_dim;
|
||||||
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||||
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int nk = num_kv_heads * embed_dim;
|
||||||
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
|
const int head_idx = i / embed_dim;
|
||||||
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
|
const int rot_offset = i % embed_dim;
|
||||||
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||||
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename scalar_t, bool IS_NEOX>
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void rotary_embedding_kernel(
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
@ -54,27 +90,29 @@ __global__ void rotary_embedding_kernel(
|
|||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
const int embed_dim = rot_dim / 2;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||||
const scalar_t* cos_ptr = cache_ptr;
|
}
|
||||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
|
||||||
|
|
||||||
const int nq = num_heads * embed_dim;
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
__global__ void batched_rotary_embedding_kernel(
|
||||||
const int head_idx = i / embed_dim;
|
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||||
const int rot_offset = i % embed_dim;
|
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
sin_ptr, rot_offset, embed_dim);
|
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
|
||||||
}
|
const int rot_dim,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride,
|
||||||
|
const int num_heads,
|
||||||
|
const int num_kv_heads,
|
||||||
|
const int head_size) {
|
||||||
|
// Each thread block is responsible for one token.
|
||||||
|
const int token_idx = blockIdx.x;
|
||||||
|
int64_t pos = positions[token_idx];
|
||||||
|
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||||
|
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||||
|
|
||||||
const int nk = num_kv_heads * embed_dim;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
|
||||||
const int head_idx = i / embed_dim;
|
|
||||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
|
||||||
const int rot_offset = i % embed_dim;
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
|
||||||
sin_ptr, rot_offset, embed_dim);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
@ -128,3 +166,61 @@ void rotary_embedding(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Batched version of rotary embedding, pack multiple LoRAs together
|
||||||
|
and process in batched manner.
|
||||||
|
*/
|
||||||
|
void batched_rotary_embedding(
|
||||||
|
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
|
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||||
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||||
|
int head_size,
|
||||||
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
|
bool is_neox,
|
||||||
|
int rot_dim,
|
||||||
|
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||||
|
) {
|
||||||
|
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||||
|
int num_heads = query.size(-1) / head_size;
|
||||||
|
int num_kv_heads = key.size(-1) / head_size;
|
||||||
|
int64_t query_stride = query.stride(-2);
|
||||||
|
int64_t key_stride = key.stride(-2);
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
query.scalar_type(),
|
||||||
|
"rotary_embedding",
|
||||||
|
[&] {
|
||||||
|
if (is_neox) {
|
||||||
|
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||||
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
} else {
|
||||||
|
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||||
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|||||||
@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
&rotary_embedding,
|
&rotary_embedding,
|
||||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"batched_rotary_embedding",
|
||||||
|
&batched_rotary_embedding,
|
||||||
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
|
||||||
|
|
||||||
// Quantization ops
|
// Quantization ops
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from allclose_default import get_default_atol, get_default_rtol
|
from allclose_default import get_default_atol, get_default_rtol
|
||||||
|
from itertools import accumulate
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
IS_NEOX_STYLE = [True, False]
|
IS_NEOX_STYLE = [True, False]
|
||||||
@ -72,3 +73,135 @@ def test_rotary_embedding(
|
|||||||
ref_key,
|
ref_key,
|
||||||
atol=get_default_atol(out_key),
|
atol=get_default_atol(out_key),
|
||||||
rtol=get_default_rtol(out_key))
|
rtol=get_default_rtol(out_key))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_batched_rotary_embedding(
|
||||||
|
is_neox_style: bool,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
device: str,
|
||||||
|
max_position: int = 8192,
|
||||||
|
base: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
if rotary_dim is None:
|
||||||
|
rotary_dim = head_size
|
||||||
|
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||||
|
"type": "linear",
|
||||||
|
"factor": (1, )
|
||||||
|
})
|
||||||
|
rope = rope.to(dtype=dtype)
|
||||||
|
|
||||||
|
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||||
|
query = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_heads * head_size,
|
||||||
|
dtype=dtype)
|
||||||
|
key = torch.randn_like(query)
|
||||||
|
|
||||||
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
|
# because the custom kernel is in-place.
|
||||||
|
ref_query, ref_key = rope._forward(positions, query, key)
|
||||||
|
out_query, out_key = rope.forward(positions,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
offsets=torch.zeros(batch_size * seq_len,
|
||||||
|
dtype=int,
|
||||||
|
device=device))
|
||||||
|
# Compare the results.
|
||||||
|
assert torch.allclose(out_query,
|
||||||
|
ref_query,
|
||||||
|
atol=get_default_atol(out_query),
|
||||||
|
rtol=get_default_rtol(out_query))
|
||||||
|
assert torch.allclose(out_key,
|
||||||
|
ref_key,
|
||||||
|
atol=get_default_atol(out_key),
|
||||||
|
rtol=get_default_rtol(out_key))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_batched_rotary_embedding_multi_lora(
|
||||||
|
is_neox_style: bool,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
device: str,
|
||||||
|
max_position: int = 8192,
|
||||||
|
base: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
if rotary_dim is None:
|
||||||
|
rotary_dim = head_size
|
||||||
|
scaling_factors: List[int] = [1, 2, 4]
|
||||||
|
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||||
|
"type": "linear",
|
||||||
|
"factor": tuple(scaling_factors)
|
||||||
|
})
|
||||||
|
rope = rope.to(dtype=dtype)
|
||||||
|
|
||||||
|
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||||
|
query = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_heads * head_size,
|
||||||
|
dtype=dtype)
|
||||||
|
key = torch.randn_like(query)
|
||||||
|
|
||||||
|
offset_map = torch.tensor(
|
||||||
|
list(
|
||||||
|
accumulate([0] + [
|
||||||
|
max_position * scaling_factor * 2
|
||||||
|
for scaling_factor in scaling_factors[:-1]
|
||||||
|
])))
|
||||||
|
query_types = torch.randint(0,
|
||||||
|
len(scaling_factors), (batch_size, seq_len),
|
||||||
|
device=device)
|
||||||
|
query_offsets = offset_map[query_types]
|
||||||
|
|
||||||
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
|
# because the custom kernel is in-place.
|
||||||
|
ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
|
||||||
|
out_query, out_key = rope.forward(positions, query, key,
|
||||||
|
query_offsets.flatten())
|
||||||
|
# Compare the results.
|
||||||
|
assert torch.allclose(out_query,
|
||||||
|
ref_query,
|
||||||
|
atol=get_default_atol(out_query),
|
||||||
|
rtol=get_default_rtol(out_query))
|
||||||
|
assert torch.allclose(out_key,
|
||||||
|
ref_key,
|
||||||
|
atol=get_default_atol(out_key),
|
||||||
|
rtol=get_default_rtol(out_key))
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Rotary Positional Embeddings."""
|
"""Rotary Positional Embeddings."""
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||||
@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module):
|
|||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
key_pass = key[..., self.rotary_dim:]
|
key_pass = key[..., self.rotary_dim:]
|
||||||
|
|
||||||
cos_sin = self.cos_sin_cache[positions]
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
||||||
|
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||||
|
if offsets is not None else positions]
|
||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
if self.is_neox_style:
|
if self.is_neox_style:
|
||||||
# NOTE(woosuk): Here we assume that the positions tensor has the
|
# NOTE(woosuk): Here we assume that the positions tensor has the
|
||||||
@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# ops.rotary_embedding() is an in-place operation that
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
||||||
# updates the query and key tensors.
|
# ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that
|
||||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
# update the query and key tensors.
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
if offsets is not None:
|
||||||
|
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||||
|
self.cos_sin_cache,
|
||||||
|
self.is_neox_style, self.rotary_dim,
|
||||||
|
offsets)
|
||||||
|
else:
|
||||||
|
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||||
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
max_position_embeddings: int,
|
max_position_embeddings: int,
|
||||||
base: int,
|
base: int,
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
scaling_factor: float,
|
scaling_factors: Union[List[float], float],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
if isinstance(scaling_factors, float):
|
||||||
|
scaling_factors = [scaling_factors]
|
||||||
|
self.scaling_factors = scaling_factors
|
||||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
is_neox_style)
|
is_neox_style)
|
||||||
|
|
||||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
inv_freq = self._compute_inv_freq(self.base)
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
cache_list = []
|
||||||
# maximum length before applying the rope scaling.
|
for scaling_factor in self.scaling_factors:
|
||||||
# Thus, the maximum length after applying the rope scaling is
|
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||||
# self.max_position_embeddings * self.scaling_factor.
|
# maximum length before applying the rope scaling.
|
||||||
max_len = self.max_position_embeddings * self.scaling_factor
|
# Thus, the maximum length after applying the rope scaling is
|
||||||
t = torch.arange(max_len, dtype=torch.float)
|
# self.max_position_embeddings * self.scaling_factor.
|
||||||
t = t / self.scaling_factor
|
max_len = self.max_position_embeddings * scaling_factor
|
||||||
|
t = torch.arange(max_len, dtype=torch.float)
|
||||||
|
t = t / scaling_factor
|
||||||
|
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
return cache
|
cache_list.append(cache)
|
||||||
|
return torch.cat(cache_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user