[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)
This commit is contained in:
parent 1e57b1ee63
commit 85ac82d228
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
 void rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
+  // num_tokens = batch_size * seq_len
+  int64_t num_tokens = positions.numel();
+  int positions_ndim = positions.dim();
+
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  // hidden_size = num_heads * head_size
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
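The core of this hunk is the stride change: the per-token stride is now taken from the sequence dimension implied by positions.dim() instead of being hard-coded as stride(-2), which is only a per-token stride for the flattened layouts. A minimal standalone PyTorch sketch of the difference (not part of the commit; tensor sizes chosen arbitrarily for illustration):

    import torch

    # [num_tokens, num_heads * head_size] vs. [num_tokens, num_heads, head_size]
    q_flat = torch.randn(16, 8 * 64)
    q_split = torch.randn(16, 8, 64)

    # stride(-2) is the distance between tokens only in the flat layout;
    # in the split-head layout it is the distance between heads.
    print(q_flat.stride(-2), q_split.stride(-2))    # 512, 64

    # Indexing the stride by the token/sequence dimension derived from
    # positions.dim() gives the per-token stride for both layouts.
    positions = torch.arange(16)           # positions_ndim == 1
    seq_dim_idx = positions.dim() - 1      # 0
    print(q_flat.stride(seq_dim_idx), q_split.stride(seq_dim_idx))   # 512, 512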
@@ -165,19 +201,58 @@ and process in batched manner.
 void batched_rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
 ) {
+  // num_tokens = batch_size * seq_len
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
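Both entry points now accept the same set of input layouts: positions fixes num_tokens, and query/key only need a hidden size per token that divides evenly into head_size, so the flat and split-head layouts pass the same checks. An illustrative plain-PyTorch restatement of those acceptance rules (check_rope_shapes is a hypothetical helper, not code from this commit; sizes are arbitrary):

    import torch

    def check_rope_shapes(positions, query, key, head_size):
        # num_tokens = batch_size * seq_len
        num_tokens = positions.numel()
        positions_ndim = positions.dim()
        assert positions_ndim in (1, 2), \
            "positions must have shape [num_tokens] or [batch_size, seq_len]"
        if positions_ndim == 1:
            assert query.size(0) == positions.size(0) and \
                key.size(0) == positions.size(0)
        else:
            assert query.size(0) == positions.size(0) and \
                key.size(0) == positions.size(0) and \
                query.size(1) == positions.size(1) and \
                key.size(1) == positions.size(1)
        # hidden_size = num_heads * head_size, with or without a split head dim
        query_hidden_size = query.numel() // num_tokens
        key_hidden_size = key.numel() // num_tokens
        assert query_hidden_size % head_size == 0
        assert key_hidden_size % head_size == 0
        # GQA: number of query heads must be a multiple of the number of KV heads
        assert (query_hidden_size // head_size) % (key_hidden_size // head_size) == 0

    positions = torch.randint(0, 4096, (2, 8))   # [batch_size, seq_len]
    check_rope_shapes(positions, torch.randn(2, 8, 4 * 64), torch.randn(2, 8, 2 * 64), 64)
    check_rope_shapes(positions, torch.randn(2, 8, 4, 64), torch.randn(2, 8, 2, 64), 64)
    check_rope_shapes(positions.flatten(), torch.randn(16, 4, 64), torch.randn(16, 2, 64), 64)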
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from itertools import accumulate, product
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import pytest
 import torch
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
 ]
 
 
+def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads * head_size)
+
+
+def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                            head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size)
+
+
+TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
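For a concrete configuration (sizes here chosen purely for illustration), the two helpers make every test case run once with the hidden dimension flattened and once with the head dimension kept separate:

    batch_size, seq_len, num_heads, head_size = 2, 16, 8, 64

    # _get_flat_tensor_shape keeps num_heads folded into the hidden dim:
    flat_shape = (batch_size, seq_len, num_heads * head_size)     # (2, 16, 512)
    # _get_batch_tensor_shape keeps the head dim as its own axis:
    batch_shape = (batch_size, seq_len, num_heads, head_size)     # (2, 16, 8, 64)

    assert flat_shape == (2, 16, 512) and batch_shape == (2, 16, 8, 64)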
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
 @torch.inference_mode()
 def test_batched_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def apply_pure_rope(
-        self,
-        input_positions: torch.Tensor,
-        q_pe: torch.Tensor,
-        k_pe: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        seq_len = input_positions.size(0)
-        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-
-        q_pe, k_pe = self.rotary_emb(
-            input_positions,
-            q_pe.reshape(seq_len, -1),
-            k_pe.reshape(seq_len, -1),
-        )
-        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
-        return q_pe, k_pe
-
     def forward(
         self,
         layer: AttentionLayer,
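apply_pure_rope existed only to flatten the 3-D q_pe/k_pe tensors around the rotary_emb call and restore their shape afterwards; because the kernels now accept [num_tokens, num_heads, head_size] directly, the wrapper can be deleted. A rough standalone sketch of the round-trip it used to perform (illustrative sizes; rotary_emb and positions appear only in comments, not called here):

    import torch

    num_tokens, num_heads, rope_dim = 4, 16, 64
    q_pe = torch.randn(num_tokens, num_heads, rope_dim)
    k_pe = torch.randn(num_tokens, 1, rope_dim)

    # Old path: flatten to 2-D for the rotary op, then restore the head dim.
    q_flat = q_pe.reshape(num_tokens, -1)    # [num_tokens, num_heads * rope_dim]
    k_flat = k_pe.reshape(num_tokens, -1)
    # ... q_flat, k_flat = rotary_emb(positions, q_flat, k_flat) ...
    q_back = q_flat.view(num_tokens, num_heads, rope_dim)

    # New path: pass the 3-D tensors straight through, e.g.
    # q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
    assert torch.equal(q_back, q_pe)   # the reshape round-trip was pure bookkeeping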
@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
-        rope_fn = (self.rotary_emb
-                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                rope_fn(
+                self.rotary_emb(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module):
                                         prefix=f"{prefix}.o_proj")
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
-            self.use_normal_rope = False
-        else:
-            self.use_normal_rope = True
+
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module):
         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = latent_cache[:, :, self.kv_lora_rank:]
 
-        if self.use_normal_rope:
-            seq_len = positions.size(0)
-            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-            q_pe = q_pe.reshape(seq_len, -1)
-            k_pe = k_pe.reshape(seq_len, -1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.use_normal_rope:
-            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
         q[..., self.qk_nope_head_dim:] = q_pe
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope