diff --git a/csrc/ops.h b/csrc/ops.h index 2ee5df4cac54..3ecfd2cd9bf3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); -void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - std::optional key, - int64_t head_size, torch::Tensor& cos_sin_cache, - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets); - void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 266f2a0667a2..b5645b33b907 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel( token_idx, query_stride, key_stride, head_stride); } -template -__global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // nullptr or - // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int64_t head_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = - cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride, head_stride); -} - } // namespace vllm void rotary_embedding( @@ -211,96 +182,3 @@ void rotary_embedding( } }); } - -/* -Batched version of rotary embedding, pack multiple LoRAs together -and process in batched manner. -*/ -void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - std::optional - key, // null or - // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] -) { - // num_tokens = batch_size * seq_len - int64_t num_tokens = cos_sin_cache_offsets.size(0); - TORCH_CHECK( - positions.size(0) == num_tokens || positions.numel() == num_tokens, - "positions must have the same num_tokens or batch_size as " - "cos_sin_cache_offsets"); - - int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key - TORCH_CHECK( - positions_ndim == 1 || positions_ndim == 2, - "positions must have shape [num_tokens] or [batch_size, seq_len]"); - if (positions_ndim == 1) { - TORCH_CHECK(query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)), - "query, key and positions must have the same number of tokens"); - } - if (positions_ndim == 2) { - TORCH_CHECK( - query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)) && - query.size(1) == positions.size(1) && - (!key.has_value() || key->size(1) == positions.size(1)), - "query, key and positions must have the same batch_size and seq_len"); - } - - // Make sure head_size is valid for query and key - int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); - - // Make sure query and key have concistent number of heads - int num_heads = query_hidden_size / head_size; - int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; - TORCH_CHECK(num_heads % num_kv_heads == 0); - - int seq_dim_idx = positions_ndim - 1; - int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; - // Determine head stride: for [*, heads, head_size] use stride of last dim; - // for flat [*, heads*head_size], heads blocks are contiguous of size - // head_size - int query_ndim = query.dim(); - int64_t head_stride = - (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } - }); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9e4586e8f045..c63fa7cffc78 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key - // (supports multiple loras). - ops.def( - "batched_rotary_embedding(Tensor positions, Tensor! query," - " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox," - " int rot_dim," - " Tensor cos_sin_cache_offsets) -> ()"); - ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); - // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AWQ. diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index ab6f1ccf881f..bf9b1d9b4401 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate, product +from itertools import product from typing import Callable, Optional import pytest @@ -111,151 +111,6 @@ def test_rotary_embedding( "expected returned key to be None" -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding( - is_neox_style: bool, - tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": (1, ) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) - query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) if use_key else None - - # slice tensor if required, noop otherwise - query = query[..., :head_size] - key = key[..., :head_size] if use_key else None - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key) - out_query, out_key = rope.forward(positions, - query, - key, - offsets=torch.zeros(batch_size * seq_len, - dtype=torch.long, - device=device)) - # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - scaling_factors: list[int] = [1, 2, 4] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) if use_key else None - - offset_map = torch.tensor( - list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) - query_offsets = offset_map[query_types] - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key, - query_offsets) - out_query, out_key = rope.forward(positions, query, key, - query_offsets.flatten()) - # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index d1fd960bf115..5857dd5ba3fa 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding def rotary_embedding_opcheck(rot, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None): + key: Optional[torch.Tensor] = None): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - opcheck(torch.ops._C.batched_rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style, rot.rotary_dim, offsets)) - else: - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + # ops.rotary_embedding() is a in-place operation + # that updates the query and key tensors. + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) @pytest.mark.parametrize("device", ["cuda"]) @@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) - offsets = torch.zeros(batch_size * seq_len, - device=device, - dtype=torch.long) - rotary_embedding_opcheck(rot, positions, query, key, offsets) # if we have a contiguous head stride, test the alternate # [..., num_heads * head_dim] shape/layout diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3b2e859b76af..93b4f87ed260 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -257,16 +257,6 @@ def rotary_embedding( cos_sin_cache, is_neox) -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor], head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) - - # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index c2868c040aa1..59b0aed32150 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -148,17 +148,6 @@ class ipex_ops: head_size, cos_sin_cache, is_neox, rot_dim) - @staticmethod - def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim, - cos_sin_cache_offsets) - @staticmethod def rms_norm(input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index be25e90abf82..db50eb08db3f 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) @@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm import _custom_ops as ops @@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def forward_xpu( @@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. if key is None: # XPU kernel doesn't support key=None so fall back to native impl # TODO(sarckk): add support for optional key in # ipex.llm.functional.rotary_embedding_batched - return self.forward_native(positions, query, key, offsets) + return self.forward_native(positions, query, key) else: - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def extra_repr(self) -> str: