mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 14:27:06 +08:00
[Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
749f792553
commit
d93c976a0d
@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
// head_size]
|
// head_size]
|
||||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||||
const int64_t query_stride, const int64_t key_stride) {
|
const int64_t query_stride, const int64_t key_stride,
|
||||||
|
const int64_t head_stride) {
|
||||||
const int embed_dim = rot_dim / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
const scalar_t* cos_ptr = cache_ptr;
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
const int nq = num_heads * embed_dim;
|
const int nq = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
const int64_t token_head =
|
||||||
|
token_idx * query_stride + head_idx * head_stride;
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||||
@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
const int nk = num_kv_heads * embed_dim;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
const int64_t token_head =
|
||||||
|
token_idx * key_stride + head_idx * head_stride;
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||||
@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
|
|||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
// 2]
|
// 2]
|
||||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||||
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
|
|||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
token_idx, query_stride, key_stride);
|
token_idx, query_stride, key_stride, head_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool IS_NEOX>
|
template <typename scalar_t, bool IS_NEOX>
|
||||||
@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
|
|||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
// 2]
|
// 2]
|
||||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||||
// or [num_tokens]
|
|
||||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||||
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
|
|||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
token_idx, query_stride, key_stride);
|
token_idx, query_stride, key_stride, head_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
@ -179,6 +183,12 @@ void rotary_embedding(
|
|||||||
int seq_dim_idx = positions_ndim - 1;
|
int seq_dim_idx = positions_ndim - 1;
|
||||||
int64_t query_stride = query.stride(seq_dim_idx);
|
int64_t query_stride = query.stride(seq_dim_idx);
|
||||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||||
|
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||||
|
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||||
|
// head_size
|
||||||
|
int query_ndim = query.dim();
|
||||||
|
int64_t head_stride =
|
||||||
|
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
@ -190,14 +200,14 @@ void rotary_embedding(
|
|||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
||||||
num_heads, num_kv_heads, head_size);
|
head_stride, num_heads, num_kv_heads, head_size);
|
||||||
} else {
|
} else {
|
||||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||||
key_stride, num_heads, num_kv_heads, head_size);
|
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -263,6 +273,12 @@ void batched_rotary_embedding(
|
|||||||
int seq_dim_idx = positions_ndim - 1;
|
int seq_dim_idx = positions_ndim - 1;
|
||||||
int64_t query_stride = query.stride(seq_dim_idx);
|
int64_t query_stride = query.stride(seq_dim_idx);
|
||||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||||
|
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||||
|
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||||
|
// head_size
|
||||||
|
int query_ndim = query.dim();
|
||||||
|
int64_t head_stride =
|
||||||
|
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
@ -276,7 +292,7 @@ void batched_rotary_embedding(
|
|||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
key_stride, num_heads, num_kv_heads, head_size);
|
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||||
} else {
|
} else {
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
@ -284,7 +300,7 @@ void batched_rotary_embedding(
|
|||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
key_stride, num_heads, num_kv_heads, head_size);
|
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
|||||||
return (batch_size, seq_len, num_heads * head_size)
|
return (batch_size, seq_len, num_heads * head_size)
|
||||||
|
|
||||||
|
|
||||||
|
# For testing sliced tensors
|
||||||
|
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||||
|
head_size: int) -> tuple[int, ...]:
|
||||||
|
return (batch_size, seq_len, num_heads, head_size + 64)
|
||||||
|
|
||||||
|
|
||||||
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||||
head_size: int) -> tuple[int, ...]:
|
head_size: int) -> tuple[int, ...]:
|
||||||
return (batch_size, seq_len, num_heads, head_size)
|
return (batch_size, seq_len, num_heads, head_size)
|
||||||
|
|
||||||
|
|
||||||
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
|
TENSORS_SHAPES_FN = [
|
||||||
|
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
@ -79,6 +87,10 @@ def test_rotary_embedding(
|
|||||||
query = torch.randn(query_shape, dtype=dtype)
|
query = torch.randn(query_shape, dtype=dtype)
|
||||||
key = torch.randn_like(query) if use_key else None
|
key = torch.randn_like(query) if use_key else None
|
||||||
|
|
||||||
|
# slice tensor if required, noop otherwise
|
||||||
|
query = query[..., :head_size]
|
||||||
|
key = key[..., :head_size] if use_key else None
|
||||||
|
|
||||||
# NOTE(woosuk): The reference implementation should be executed first
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
# because the custom kernel is in-place.
|
# because the custom kernel is in-place.
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||||
|
|||||||
@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
|
|||||||
@pytest.mark.parametrize("head_size", [32, 108])
|
@pytest.mark.parametrize("head_size", [32, 108])
|
||||||
@pytest.mark.parametrize("seq_len", [11, 1024])
|
@pytest.mark.parametrize("seq_len", [11, 1024])
|
||||||
@pytest.mark.parametrize("use_key", [True, False])
|
@pytest.mark.parametrize("use_key", [True, False])
|
||||||
|
@pytest.mark.parametrize("head_stride_is_contingous", [True, False])
|
||||||
def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||||
is_neox_style, rotary_dim, head_size,
|
is_neox_style, rotary_dim, head_size,
|
||||||
seq_len, use_key):
|
seq_len, use_key, head_stride_is_contingous):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
base = 10000
|
base = 10000
|
||||||
num_heads = 7
|
num_heads = 7
|
||||||
@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
|||||||
positions = torch.randint(0,
|
positions = torch.randint(0,
|
||||||
max_position, (batch_size, seq_len),
|
max_position, (batch_size, seq_len),
|
||||||
device=device)
|
device=device)
|
||||||
|
head_stride = head_size + (64 if head_stride_is_contingous else 0)
|
||||||
|
|
||||||
query = torch.randn(batch_size,
|
query = torch.randn(batch_size,
|
||||||
seq_len,
|
seq_len,
|
||||||
num_heads * head_size,
|
num_heads,
|
||||||
|
head_stride,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device)
|
device=device)
|
||||||
key = torch.randn_like(query) if use_key else None
|
key = torch.randn_like(query) if use_key else None
|
||||||
|
query = query[..., :head_size]
|
||||||
|
key = key[..., :head_size] if use_key else None
|
||||||
|
|
||||||
rotary_embedding_opcheck(rot, positions, query, key)
|
rotary_embedding_opcheck(rot, positions, query, key)
|
||||||
offsets = torch.zeros(batch_size * seq_len,
|
offsets = torch.zeros(batch_size * seq_len,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
rotary_embedding_opcheck(rot, positions, query, key, offsets)
|
rotary_embedding_opcheck(rot, positions, query, key, offsets)
|
||||||
|
|
||||||
|
# if we have a contiguous head stride, test the alternate
|
||||||
|
# [..., num_heads * head_dim] shape/layout
|
||||||
|
if head_stride_is_contingous:
|
||||||
|
rotary_embedding_opcheck(
|
||||||
|
rot, positions, query.flatten(start_dim=-2),
|
||||||
|
key.flatten(start_dim=-2) if use_key else None)
|
||||||
|
|||||||
@ -254,14 +254,8 @@ def rotary_embedding(
|
|||||||
cos_sin_cache: torch.Tensor,
|
cos_sin_cache: torch.Tensor,
|
||||||
is_neox: bool,
|
is_neox: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
|
torch.ops._C.rotary_embedding(positions, query, key, head_size,
|
||||||
query_contiguous = query.contiguous()
|
cos_sin_cache, is_neox)
|
||||||
key_contiguous = key.contiguous() if key is not None else None
|
|
||||||
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
|
|
||||||
head_size, cos_sin_cache, is_neox)
|
|
||||||
query.copy_(query_contiguous)
|
|
||||||
if key is not None:
|
|
||||||
key.copy_(key_contiguous)
|
|
||||||
|
|
||||||
|
|
||||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
||||||
@ -269,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
|||||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
cos_sin_cache: torch.Tensor, is_neox: bool,
|
||||||
rot_dim: int,
|
rot_dim: int,
|
||||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
cos_sin_cache_offsets: torch.Tensor) -> None:
|
||||||
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
|
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
|
||||||
query_contiguous = query.contiguous()
|
|
||||||
key_contiguous = key.contiguous() if key is not None else None
|
|
||||||
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
|
|
||||||
key_contiguous, head_size,
|
|
||||||
cos_sin_cache, is_neox, rot_dim,
|
cos_sin_cache, is_neox, rot_dim,
|
||||||
cos_sin_cache_offsets)
|
cos_sin_cache_offsets)
|
||||||
query.copy_(query_contiguous)
|
|
||||||
if key is not None:
|
|
||||||
key.copy_(key_contiguous)
|
|
||||||
|
|
||||||
|
|
||||||
# layer norm ops
|
# layer norm ops
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user