[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)

Author: Isotr0py (committed by GitHub), 2025-02-07 00:46:13 +08:00
Parent: 1e57b1ee63
Commit: 85ac82d228
4 changed files with 115 additions and 57 deletions
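
In short: the rotary embedding custom ops now accept query and key either flattened to [..., num_heads * head_size] or with an explicit head dimension, so callers no longer have to reshape around the op. A minimal usage sketch, not part of this commit (illustrative sizes; assumes a CUDA build of vLLM so the custom kernel rather than the native fallback is exercised):

```python
# Sketch only: illustrative sizes, CUDA build of vLLM assumed.
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

num_tokens, num_heads, head_size = 16, 8, 128
rope = get_rope(head_size, rotary_dim=head_size, max_position=4096,
                base=10000, is_neox_style=True).to("cuda", torch.float16)
positions = torch.randint(0, 4096, (num_tokens,), device="cuda")

# Flattened layout: the only one the kernel accepted before this change.
q_flat = torch.randn(num_tokens, num_heads * head_size,
                     device="cuda", dtype=torch.float16)
k_flat = torch.randn_like(q_flat)
q_flat, k_flat = rope(positions, q_flat, k_flat)

# Explicit head dimension: accepted after this change.
q = torch.randn(num_tokens, num_heads, head_size,
                device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
q, k = rope(positions, q, k)
```

The same relaxation applies to the batched op further down; both paths only require that positions, query, and key agree on the token dimensions.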


@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
 void rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
+  // num_tokens = batch_size * seq_len
+  int64_t num_tokens = positions.numel();
+  int positions_ndim = positions.dim();
+
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  // hidden_size = num_heads * head_size
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
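
For readers skimming the hunk above, the new shape bookkeeping can be mirrored in a few lines of Python. This is an illustration of the launcher's TORCH_CHECKs only, not vLLM code; the helper name check_rope_inputs is made up:

```python
# Python mirror of the validation done before launching the kernel
# (illustration only; check_rope_inputs is a hypothetical helper).
import torch


def check_rope_inputs(positions: torch.Tensor, query: torch.Tensor,
                      key: torch.Tensor, head_size: int) -> tuple[int, int]:
    """Return (query_stride, key_stride) taken over the token/seq dimension."""
    assert positions.dim() in (1, 2), \
        "positions must be [num_tokens] or [batch_size, seq_len]"
    num_tokens = positions.numel()
    # query/key may be flattened or carry an explicit head dim; either way
    # their leading dims must match positions.
    assert query.shape[:positions.dim()] == positions.shape
    assert key.shape[:positions.dim()] == positions.shape
    # hidden_size must be a multiple of head_size ...
    query_hidden = query.numel() // num_tokens
    key_hidden = key.numel() // num_tokens
    assert query_hidden % head_size == 0 and key_hidden % head_size == 0
    # ... and the number of query heads a multiple of the KV heads.
    assert (query_hidden // head_size) % (key_hidden // head_size) == 0
    # The kernel walks tokens through the stride of the last "sequence" dim,
    # so 2D, 3D, and 4D inputs all work without flattening.
    seq_dim_idx = positions.dim() - 1
    return query.stride(seq_dim_idx), key.stride(seq_dim_idx)


# Example: batched query/key with an explicit head dimension.
q = torch.randn(2, 8, 4, 64)   # [batch_size, seq_len, num_heads, head_size]
k = torch.randn(2, 8, 1, 64)   # [batch_size, seq_len, num_kv_heads, head_size]
pos = torch.randint(0, 4096, (2, 8))
print(check_rope_inputs(pos, q, k, head_size=64))
```
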
@@ -165,19 +201,58 @@ and process in batched manner.
 void batched_rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
 ) {
+  // num_tokens = batch_size * seq_len
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
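
The batched variant adds one more relaxation: cos_sin_cache_offsets may be sized per token or per batch element. A small illustration of the two layouts the new check accepts (toy sizes; illustration only):

```python
# Illustration only: the two offset layouts batched_rotary_embedding accepts.
import torch

batch_size, seq_len = 4, 32
positions = torch.randint(0, 4096, (batch_size, seq_len))

# One offset per token, i.e. num_tokens = batch_size * seq_len ...
offsets_per_token = torch.zeros(batch_size * seq_len, dtype=torch.long)
assert positions.numel() == offsets_per_token.size(0)

# ... or one offset per batch element, matching the new [batch_size] shape
# annotation above.
offsets_per_batch = torch.zeros(batch_size, dtype=torch.long)
assert positions.size(0) == offsets_per_batch.size(0)
```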


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from itertools import accumulate, product
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import pytest
 import torch
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
 ]
 
 
+def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads * head_size)
+
+
+def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                            head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size)
+
+
+TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
 @torch.inference_mode()
 def test_batched_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
    batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
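
The new tensor_shape_fn parametrization runs each test case once per layout. The equivalence it exercises can be sketched against the pure-PyTorch forward_native reference, which also runs on CPU (sketch only; toy sizes):

```python
# Sketch of the layout-equivalence property the updated tests cover.
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

batch_size, seq_len, num_heads, head_size = 2, 8, 4, 64
rope = get_rope(head_size, rotary_dim=head_size, max_position=4096,
                base=10000, is_neox_style=True)
positions = torch.randint(0, 4096, (batch_size, seq_len))

q = torch.randn(batch_size, seq_len, num_heads, head_size)
k = torch.randn_like(q)

# Flattened layout and explicit-head layout should rotate identically.
q_flat, k_flat = rope.forward_native(positions,
                                     q.reshape(batch_size, seq_len, -1),
                                     k.reshape(batch_size, seq_len, -1))
q_3d, k_3d = rope.forward_native(positions, q, k)
torch.testing.assert_close(q_flat, q_3d.reshape(batch_size, seq_len, -1))
torch.testing.assert_close(k_flat, k_3d.reshape(batch_size, seq_len, -1))
```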


@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def apply_pure_rope(
-        self,
-        input_positions: torch.Tensor,
-        q_pe: torch.Tensor,
-        k_pe: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        seq_len = input_positions.size(0)
-        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-        q_pe, k_pe = self.rotary_emb(
-            input_positions,
-            q_pe.reshape(seq_len, -1),
-            k_pe.reshape(seq_len, -1),
-        )
-        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
-        return q_pe, k_pe
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
 
         assert hasattr(attn_metadata, "input_positions")
-        rope_fn = (self.rotary_emb
-                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                rope_fn(
+                self.rotary_emb(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)

@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module):
                                         prefix=f"{prefix}.o_proj")
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
-            self.use_normal_rope = False
-        else:
-            self.use_normal_rope = True
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module):
         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = latent_cache[:, :, self.kv_lora_rank:]
 
-        if self.use_normal_rope:
-            seq_len = positions.size(0)
-            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-            q_pe = q_pe.reshape(seq_len, -1)
-            k_pe = k_pe.reshape(seq_len, -1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.use_normal_rope:
-            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
         q[..., self.qk_nope_head_dim:] = q_pe
 
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
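
With the kernel handling the head dimension itself, both call sites above drop their flatten/restore round-trip. A before/after sketch of that caller-side simplification (illustrative shapes; assumes a CUDA build of vLLM and its get_rope helper, not the commit's own code):

```python
# Before/after sketch of the caller-side change; toy sizes, CUDA assumed so
# the custom kernel (rather than the native fallback) is exercised.
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

num_tokens, num_heads, rope_dim = 16, 8, 64
rope = get_rope(rope_dim, rotary_dim=rope_dim, max_position=4096,
                base=10000, is_neox_style=True).to("cuda", torch.float16)
positions = torch.randint(0, 4096, (num_tokens,), device="cuda")
q_pe = torch.randn(num_tokens, num_heads, rope_dim,
                   device="cuda", dtype=torch.float16)
k_pe = torch.randn(num_tokens, 1, rope_dim,
                   device="cuda", dtype=torch.float16)

# Before: flatten the head dim for the op, then restore the original shapes
# (what apply_pure_rope and the use_normal_rope branches used to do).
q_flat = q_pe.reshape(num_tokens, -1).clone()
k_flat = k_pe.reshape(num_tokens, -1).clone()
q_old, k_old = rope(positions, q_flat, k_flat)
q_old, k_old = q_old.view(q_pe.shape), k_old.view(k_pe.shape)

# After: pass the [num_tokens, num_heads, rope_dim] tensors directly.
q_new, k_new = rope(positions, q_pe, k_pe)
torch.testing.assert_close(q_new, q_old)
```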