Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:35:28 +08:00)
[Chore] Remove unused batched RoPE op & kernel (#24789)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in: parent 99bfef841f, commit 5febdc8750
@@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
 
-void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                              std::optional<torch::Tensor> key,
-                              int64_t head_size, torch::Tensor& cos_sin_cache,
-                              bool is_neox, int64_t rot_dim,
-                              torch::Tensor& cos_sin_cache_offsets);
-
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
       token_idx, query_stride, key_stride, head_stride);
 }
 
-template <typename scalar_t, bool IS_NEOX>
-__global__ void batched_rotary_embedding_kernel(
-    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
-                                            // [num_tokens]
-    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
-                                   // head_size] or [num_tokens, num_heads,
-                                   // head_size]
-    scalar_t* __restrict__ key,    // nullptr or
-                                   // [batch_size, seq_len, num_kv_heads,
-                                   // head_size] or [num_tokens, num_kv_heads,
-                                   // head_size]
-    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
-                                                 // 2]
-    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int64_t head_stride, const int num_heads, const int num_kv_heads,
-    const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
-  const scalar_t* cache_ptr =
-      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
-
-  apply_rotary_embedding<scalar_t, IS_NEOX>(
-      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride, head_stride);
-}
-
 }  // namespace vllm
 
 void rotary_embedding(
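For readers tracing what is being deleted: the removed batched kernel differs from the plain rotary_embedding_kernel only in how it picks the cached cos/sin row, using cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim, i.e. offset-plus-position instead of position alone. A minimal PyTorch sketch of that row selection (illustrative only; the tensor names and shapes below are made up, not vLLM APIs):

import torch

# Hypothetical sizes, chosen only so the indexing below stays in range.
max_position, rot_dim, num_tokens = 16, 8, 4
cos_sin_cache = torch.randn(max_position, rot_dim)  # [max_position, rot_dim]
positions = torch.randint(0, 8, (num_tokens,))      # [num_tokens]
offsets = torch.randint(0, 8, (num_tokens,))        # [num_tokens]

# Plain kernel: the cached row is selected by position alone.
rows_plain = cos_sin_cache[positions]               # [num_tokens, rot_dim]
# Removed batched kernel: the row is selected by offset + position,
# mirroring `cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim`.
rows_batched = cos_sin_cache[offsets + positions]   # [num_tokens, rot_dim]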
@@ -211,96 +182,3 @@ void rotary_embedding(
     }
   });
 }
-
-/*
-Batched version of rotary embedding, pack multiple LoRAs together
-and process in batched manner.
-*/
-void batched_rotary_embedding(
-    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size] or
-                           // [batch_size, seq_len, num_heads, head_size] or
-                           // [num_tokens, num_heads, head_size]
-    std::optional<torch::Tensor>
-        key,  // null or
-              // [batch_size, seq_len, num_kv_heads * head_size] or
-              // [num_tokens, num_kv_heads * head_size] or
-              // [batch_size, seq_len, num_heads, head_size] or
-              // [num_tokens, num_heads, head_size]
-    int64_t head_size,
-    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-    bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
-) {
-  // num_tokens = batch_size * seq_len
-  int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  TORCH_CHECK(
-      positions.size(0) == num_tokens || positions.numel() == num_tokens,
-      "positions must have the same num_tokens or batch_size as "
-      "cos_sin_cache_offsets");
-
-  int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key
-  TORCH_CHECK(
-      positions_ndim == 1 || positions_ndim == 2,
-      "positions must have shape [num_tokens] or [batch_size, seq_len]");
-  if (positions_ndim == 1) {
-    TORCH_CHECK(query.size(0) == positions.size(0) &&
-                    (!key.has_value() || key->size(0) == positions.size(0)),
-                "query, key and positions must have the same number of tokens");
-  }
-  if (positions_ndim == 2) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) &&
-            (!key.has_value() || key->size(0) == positions.size(0)) &&
-            query.size(1) == positions.size(1) &&
-            (!key.has_value() || key->size(1) == positions.size(1)),
-        "query, key and positions must have the same batch_size and seq_len");
-  }
-
-  // Make sure head_size is valid for query and key
-  int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
-  TORCH_CHECK(query_hidden_size % head_size == 0);
-  TORCH_CHECK(key_hidden_size % head_size == 0);
-
-  // Make sure query and key have concistent number of heads
-  int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
-  TORCH_CHECK(num_heads % num_kv_heads == 0);
-
-  int seq_dim_idx = positions_ndim - 1;
-  int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
-  // Determine head stride: for [*, heads, head_size] use stride of last dim;
-  // for flat [*, heads*head_size], heads blocks are contiguous of size
-  // head_size
-  int query_ndim = query.dim();
-  int64_t head_stride =
-      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
-    if (is_neox) {
-      vllm::batched_rotary_embedding_kernel<scalar_t, true>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    } else {
-      vllm::batched_rotary_embedding_kernel<scalar_t, false>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    }
-  });
-}
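The removed host launcher mostly duplicates the shape handling of the surviving rotary_embedding entry point: num_tokens comes from the offsets tensor, head counts are inferred from the flattened hidden sizes, and one thread block is launched per token with at most 512 threads. A small Python sketch that mirrors that arithmetic (illustrative only, with made-up shapes; this is not a vLLM API):

import torch

batch_size, seq_len, num_heads, num_kv_heads, head_size, rot_dim = 2, 3, 8, 4, 64, 64
query = torch.randn(batch_size, seq_len, num_heads * head_size)
key = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
cos_sin_cache_offsets = torch.zeros(batch_size * seq_len, dtype=torch.long)

num_tokens = cos_sin_cache_offsets.size(0)        # batch_size * seq_len
query_hidden_size = query.numel() // num_tokens   # num_heads * head_size
key_hidden_size = key.numel() // num_tokens       # num_kv_heads * head_size
assert query_hidden_size % head_size == 0 and key_hidden_size % head_size == 0
assert (query_hidden_size // head_size) % (key_hidden_size // head_size) == 0

# One thread block per token, capped at 512 threads per block.
grid = num_tokens
block = min(num_heads * rot_dim // 2, 512)
print(grid, block)  # 6 256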
@@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
 
-  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
-  // (supports multiple loras).
-  ops.def(
-      "batched_rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor!? key, int head_size,"
-      " Tensor cos_sin_cache, bool is_neox,"
-      " int rot_dim,"
-      " Tensor cos_sin_cache_offsets) -> ()");
-  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
-
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AWQ.
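Only the plain rotary_embedding binding remains registered. A rough sketch of calling it from Python against the schema kept above (assumes a CUDA-enabled vLLM build so the compiled _C extension loads; the shapes are made up):

import torch
from vllm import _custom_ops as ops  # loads the compiled _C extension

num_tokens, num_heads, head_size = 4, 8, 64
positions = torch.zeros(num_tokens, dtype=torch.long, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
cos_sin_cache = torch.randn(8192, head_size,  # [max_position, rot_dim]
                            dtype=torch.float16, device="cuda")

# In-place op: query and key are updated directly.
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                     True)  # is_neox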
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from itertools import accumulate, product
+from itertools import product
 from typing import Callable, Optional
 
 import pytest
@@ -111,151 +111,6 @@ def test_rotary_embedding(
             "expected returned key to be None"
 
 
-@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("seq_len", SEQ_LENS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_key", USE_KEY)
-@torch.inference_mode()
-def test_batched_rotary_embedding(
-    is_neox_style: bool,
-    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: Optional[int],
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    use_key: bool,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": (1, )
-    })
-    rope = rope.to(dtype=dtype, device=torch.get_default_device())
-
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
-    query = torch.randn(query_shape, dtype=dtype)
-    key = torch.randn_like(query) if use_key else None
-
-    # slice tensor if required, noop otherwise
-    query = query[..., :head_size]
-    key = key[..., :head_size] if use_key else None
-
-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_query, ref_key = rope.forward_native(positions, query, key)
-    out_query, out_key = rope.forward(positions,
-                                      query,
-                                      key,
-                                      offsets=torch.zeros(batch_size * seq_len,
-                                                          dtype=torch.long,
-                                                          device=device))
-    # Compare the results.
-    torch.testing.assert_close(out_query,
-                               ref_query,
-                               atol=get_default_atol(out_query),
-                               rtol=get_default_rtol(out_query))
-    if use_key:
-        torch.testing.assert_close(out_key,
-                                   ref_key,
-                                   atol=get_default_atol(out_key),
-                                   rtol=get_default_rtol(out_key))
-    else:
-        assert ref_key is None and out_key is None, \
-            "expected returned key to be None"
-
-
-@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("seq_len", SEQ_LENS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_key", USE_KEY)
-@torch.inference_mode()
-def test_batched_rotary_embedding_multi_lora(
-    is_neox_style: bool,
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: Optional[int],
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    use_key: bool,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    scaling_factors: list[int] = [1, 2, 4]
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": tuple(scaling_factors)
-    })
-    rope = rope.to(dtype=dtype, device=torch.get_default_device())
-
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
-    key = torch.randn_like(query) if use_key else None
-
-    offset_map = torch.tensor(
-        list(
-            accumulate([0] + [
-                max_position * scaling_factor * 2
-                for scaling_factor in scaling_factors[:-1]
-            ])))
-    query_types = torch.randint(0,
-                                len(scaling_factors), (batch_size, seq_len),
-                                device=device)
-    query_offsets = offset_map[query_types]
-
-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_query, ref_key = rope.forward_native(positions, query, key,
-                                             query_offsets)
-    out_query, out_key = rope.forward(positions, query, key,
-                                      query_offsets.flatten())
-    # Compare the results.
-    torch.testing.assert_close(out_query,
-                               ref_query,
-                               atol=get_default_atol(out_query),
-                               rtol=get_default_rtol(out_query))
-    if use_key:
-        torch.testing.assert_close(out_key,
-                                   ref_key,
-                                   atol=get_default_atol(out_key),
-                                   rtol=get_default_rtol(out_key))
-    else:
-        assert ref_key is None and out_key is None, \
-            "expected returned key to be None"
-
-
 @torch.inference_mode()
 def test_rope_module_cache():
     MAX_POSITIONS = [123, 1234]
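The remaining coverage is the plain test_rotary_embedding path, which checks forward_native against the custom kernel without any offsets argument. A stripped-down sketch of that pattern (assumes a CUDA-enabled vLLM install; the get_rope import path follows vLLM's rotary_embedding package, and the fixed tolerances below are loose, illustrative stand-ins for the test's get_default_atol/get_default_rtol helpers):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_size, rotary_dim, max_position, base = 64, 64, 8192, 10000
rope = get_rope(head_size, rotary_dim, max_position, base, True)  # neox style
rope = rope.to(dtype=torch.float16, device="cuda")

positions = torch.randint(0, max_position, (2, 16), device="cuda")
query = torch.randn(2, 16, 8 * head_size, dtype=torch.float16, device="cuda")
key = torch.randn_like(query)

# Run the reference first: the custom kernel updates query/key in place.
ref_q, ref_k = rope.forward_native(positions, query, key)
out_q, out_k = rope.forward(positions, query, key)
torch.testing.assert_close(out_q, ref_q, atol=1e-2, rtol=1e-2)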
@@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 def rotary_embedding_opcheck(rot,
                              positions: torch.Tensor,
                              query: torch.Tensor,
-                             key: Optional[torch.Tensor] = None,
-                             offsets: Optional[torch.Tensor] = None):
+                             key: Optional[torch.Tensor] = None):
     cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
 
-    # ops.rotary_embedding()/batched_rotary_embedding()
-    # are in-place operations that update the query and key tensors.
-    if offsets is not None:
-        opcheck(torch.ops._C.batched_rotary_embedding,
-                (positions, query, key, rot.head_size, cos_sin_cache,
-                 rot.is_neox_style, rot.rotary_dim, offsets))
-    else:
-        opcheck(torch.ops._C.rotary_embedding,
-                (positions, query, key, rot.head_size, cos_sin_cache,
-                 rot.is_neox_style))
+    # ops.rotary_embedding() is an in-place operation
+    # that updates the query and key tensors.
+    opcheck(torch.ops._C.rotary_embedding,
+            (positions, query, key, rot.head_size, cos_sin_cache,
+             rot.is_neox_style))
 
 
 @pytest.mark.parametrize("device", ["cuda"])
@@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     key = key[..., :head_size] if use_key else None
 
     rotary_embedding_opcheck(rot, positions, query, key)
-    offsets = torch.zeros(batch_size * seq_len,
-                          device=device,
-                          dtype=torch.long)
-    rotary_embedding_opcheck(rot, positions, query, key, offsets)
 
     # if we have a contiguous head stride, test the alternate
     # [..., num_heads * head_dim] shape/layout
@@ -257,16 +257,6 @@ def rotary_embedding(
                                   cos_sin_cache, is_neox)
 
 
-def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
-                             key: Optional[torch.Tensor], head_size: int,
-                             cos_sin_cache: torch.Tensor, is_neox: bool,
-                             rot_dim: int,
-                             cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
-                                          cos_sin_cache, is_neox, rot_dim,
-                                          cos_sin_cache_offsets)
-
-
 # layer norm ops
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              epsilon: float) -> None:
@@ -148,17 +148,6 @@ class ipex_ops:
                                              head_size, cos_sin_cache,
                                              is_neox, rot_dim)
 
-    @staticmethod
-    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
-                                 key: torch.Tensor, head_size: int,
-                                 cos_sin_cache: torch.Tensor, is_neox: bool,
-                                 rot_dim: int,
-                                 cos_sin_cache_offsets: torch.Tensor) -> None:
-        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
-                                                     head_size, cos_sin_cache,
-                                                     is_neox, rot_dim,
-                                                     cos_sin_cache_offsets)
-
     @staticmethod
     def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                  epsilon: float) -> torch.Tensor:
@@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """A PyTorch-native implementation of forward()."""
-        if offsets is not None:
-            positions = positions + offsets
         positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions)
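For callers that still carried per-request offsets into forward_native, the deleted branch spells out the equivalence: it only did positions = positions + offsets before flattening, so the same result is obtained by shifting positions at the call site. A hedged sketch (the helper name is hypothetical; rope is any RotaryEmbedding instance):

import torch

def forward_native_with_offsets(rope, positions: torch.Tensor,
                                query: torch.Tensor, key, offsets):
    # Equivalent to the removed `offsets` parameter: fold the offsets into
    # the positions before calling the native implementation.
    return rope.forward_native(positions + offsets, query, key)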
@@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm import _custom_ops as ops
 
@@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                        dtype=query.dtype)
 
-        # ops.rotary_embedding()/batched_rotary_embedding()
-        # are in-place operations that update the query and key tensors.
-        if offsets is not None:
-            ops.batched_rotary_embedding(positions, query, key, self.head_size,
-                                         self.cos_sin_cache,
-                                         self.is_neox_style, self.rotary_dim,
-                                         offsets)
-        else:
-            ops.rotary_embedding(positions, query, key, self.head_size,
-                                 self.cos_sin_cache, self.is_neox_style)
+        # ops.rotary_embedding() is an in-place operation
+        # that updates the query and key tensors.
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
         return query, key
 
     def forward_xpu(
@@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm._ipex_ops import ipex_ops as ops
 
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
-        # ops.rotary_embedding()/batched_rotary_embedding()
-        # are in-place operations that update the query and key tensors.
+        # ops.rotary_embedding() is an in-place operation
+        # that updates the query and key tensors.
         if key is None:
             # XPU kernel doesn't support key=None so fall back to native impl
             # TODO(sarckk): add support for optional key in
             # ipex.llm.functional.rotary_embedding_batched
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(positions, query, key)
         else:
-            if offsets is not None:
-                ops.batched_rotary_embedding(positions, query, key,
-                                             self.head_size,
-                                             self.cos_sin_cache,
-                                             self.is_neox_style,
-                                             self.rotary_dim, offsets)
-            else:
-                ops.rotary_embedding(positions, query, key, self.head_size,
-                                     self.cos_sin_cache, self.is_neox_style)
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
         return query, key
 
     def extra_repr(self) -> str: