[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)
Parent: 1e57b1ee63 · Commit: 85ac82d228
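For orientation, here is a minimal sketch (not part of this commit) of what the relaxed shape contract means from the Python side. It assumes vLLM's get_rope helper and the rope(positions, query, key) call pattern used by the kernel tests below, plus a CUDA-enabled build; all sizes are arbitrary.

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

batch_size, seq_len, num_heads, head_size = 2, 8, 4, 64
rope = get_rope(head_size, rotary_dim=head_size, max_position=8192,
                base=10000, is_neox_style=True).to(device="cuda",
                                                   dtype=torch.float16)

positions = torch.randint(0, 8192, (batch_size, seq_len), device="cuda")

# Flat hidden-size layout (already supported): [batch_size, seq_len, num_heads * head_size]
q = torch.randn(batch_size, seq_len, num_heads * head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
q, k = rope(positions, q, k)

# Per-head layout (newly accepted): [batch_size, seq_len, num_heads, head_size]
q = torch.randn(batch_size, seq_len, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
q, k = rope(positions, q, k)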
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
 void rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
+  // num_tokens = batch_size * seq_len
+  int64_t num_tokens = positions.numel();
+  int positions_ndim = positions.dim();
+
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  // hidden_size = num_heads * head_size
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
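The checks added above derive num_tokens from positions and the head counts from the flattened hidden size, so they hold for both the flat and the per-head layouts. A Python mirror of that bookkeeping, illustrative only (the helper name is made up):

import torch

def check_rope_inputs(positions: torch.Tensor, query: torch.Tensor,
                      key: torch.Tensor, head_size: int) -> tuple[int, int]:
    """Python mirror of the shape checks added to rotary_embedding()."""
    assert positions.dim() in (1, 2), \
        "positions must be [num_tokens] or [batch_size, seq_len]"
    num_tokens = positions.numel()

    # Leading dims of query/key must match positions.
    assert query.shape[:positions.dim()] == positions.shape
    assert key.shape[:positions.dim()] == positions.shape

    # hidden_size = num_heads * head_size, whether or not the head dim
    # is flattened into the last dimension.
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    assert query_hidden_size % head_size == 0
    assert key_hidden_size % head_size == 0
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size
    assert num_heads % num_kv_heads == 0

    # The kernel advances by the stride along the token dimension, i.e. the
    # last dimension of positions.
    seq_dim_idx = positions.dim() - 1
    return query.stride(seq_dim_idx), key.stride(seq_dim_idx)

positions = torch.randint(0, 4096, (2, 8))
query = torch.randn(2, 8, 4, 64)   # per-head layout
key = torch.randn(2, 8, 1, 64)
print(check_rope_inputs(positions, query, key, head_size=64))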
@@ -165,19 +201,58 @@ and process in batched manner.
 void batched_rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
 ) {
+  // num_tokens = batch_size * seq_len
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
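The batched variant additionally ties cos_sin_cache_offsets to positions: the offsets tensor may now hold one entry per sequence ([batch_size]) instead of one per token. Conceptually, each token looks up cos/sin at its position shifted by its sequence's offset into a cache that concatenates several scaling factors. A rough torch-only illustration of that lookup (names and sizes are made up; this is not the kernel code):

import torch

max_position, rot_dim = 4096, 64
num_scaling_factors = 2
# Cache holding one [max_position, rot_dim] block per scaling factor.
cos_sin_cache = torch.randn(num_scaling_factors * max_position, rot_dim)

batch_size, seq_len = 2, 8
positions = torch.randint(0, max_position, (batch_size, seq_len))
# One offset per sequence ([batch_size]) instead of one per token.
cos_sin_cache_offsets = torch.tensor([0, max_position])

# Each token reads cos/sin at (its position + its sequence's offset).
flat_index = positions + cos_sin_cache_offsets[:, None]
cos_sin = cos_sin_cache[flat_index]  # [batch_size, seq_len, rot_dim]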
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from itertools import accumulate, product
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import pytest
 import torch
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
 ]
 
 
+def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads * head_size)
+
+
+def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                            head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size)
+
+
+TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
 @torch.inference_mode()
 def test_batched_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
    batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
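The new tensor_shape_fn parameter makes every test case run once per layout. As a quick standalone check of what the two helpers from the hunk above produce (sample sizes are arbitrary):

def _get_flat_tensor_shape(batch_size, seq_len, num_heads, head_size):
    return (batch_size, seq_len, num_heads * head_size)

def _get_batch_tensor_shape(batch_size, seq_len, num_heads, head_size):
    return (batch_size, seq_len, num_heads, head_size)

print(_get_flat_tensor_shape(2, 8, 4, 64))   # (2, 8, 256)
print(_get_batch_tensor_shape(2, 8, 4, 64))  # (2, 8, 4, 64)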
@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def apply_pure_rope(
-        self,
-        input_positions: torch.Tensor,
-        q_pe: torch.Tensor,
-        k_pe: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        seq_len = input_positions.size(0)
-        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-
-        q_pe, k_pe = self.rotary_emb(
-            input_positions,
-            q_pe.reshape(seq_len, -1),
-            k_pe.reshape(seq_len, -1),
-        )
-        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
-        return q_pe, k_pe
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
-        rope_fn = (self.rotary_emb
-                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                rope_fn(
+                self.rotary_emb(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module):
                                         prefix=f"{prefix}.o_proj")
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
-            self.use_normal_rope = False
-        else:
-            self.use_normal_rope = True
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module):
         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = latent_cache[:, :, self.kv_lora_rank:]
 
-        if self.use_normal_rope:
-            seq_len = positions.size(0)
-            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-            q_pe = q_pe.reshape(seq_len, -1)
-            k_pe = k_pe.reshape(seq_len, -1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.use_normal_rope:
-            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
         q[..., self.qk_nope_head_dim:] = q_pe
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
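On the model side, the change removes the flatten/unflatten round trip that apply_pure_rope and the use_normal_rope flag existed for, since the rope op now accepts q_pe/k_pe with the head dimension kept separate. A minimal sketch of the old versus new call pattern, with a stubbed-in rope and illustrative shapes:

import torch

num_tokens, num_heads, rope_dim = 16, 8, 64
positions = torch.arange(num_tokens)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)
k_pe = torch.randn(num_tokens, 1, rope_dim)

def rotary_emb(positions, q, k):
    # Stand-in for self.rotary_emb; a real rope would rotate q/k here.
    return q, k

# Old pattern: flatten to 2-D, apply rope, restore the original shapes.
ori_q_shape, ori_k_shape = q_pe.shape, k_pe.shape
q_out, k_out = rotary_emb(positions,
                          q_pe.reshape(num_tokens, -1),
                          k_pe.reshape(num_tokens, -1))
q_out, k_out = q_out.view(ori_q_shape), k_out.view(ori_k_shape)

# New pattern: pass the 3-D tensors straight through.
q_out, k_out = rotary_emb(positions, q_pe, k_pe)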