vllm/tests/kernels/core/test_pos_encoding.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
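
"""Tests for vLLM's rotary positional embedding (RoPE) kernels and the
get_rope module cache."""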
from collections.abc import Callable
from itertools import product

import pytest
import torch

from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 120, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
USE_KEY = [True, False]
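
# NOTE: CUDA_DEVICES uses at most two devices; on a single-GPU machine it
# degenerates to ["cuda:0"]. The lists above are intentionally small,
# arbitrary samples to keep the parametrized grid below manageable.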
def _get_flat_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
    return (batch_size, seq_len, num_heads * head_size)


# For testing sliced tensors
def _get_padded_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
    return (batch_size, seq_len, num_heads, head_size + 64)


def _get_batch_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
    return (batch_size, seq_len, num_heads, head_size)


TENSORS_SHAPES_FN = [
_get_batch_tensor_shape,
_get_flat_tensor_shape,
_get_padded_tensor_shape,
]
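
# Three input layouts are exercised: a 4-D (batch, seq, heads, head_size)
# tensor, a flat tensor with the heads folded into the last dimension, and a
# padded tensor whose trailing slice yields a non-contiguous view.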
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: int | None,
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
rope_theta: float = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
    rope = get_rope(
        head_size, rotary_dim, max_position, is_neox_style, rope_parameters
    )
rope = rope.to(dtype=dtype, device=torch.get_default_device())
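    # Random token positions in [0, max_position) for every (batch, seq) slot.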
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
    # slice tensor if required, a no-op otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
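    # With the padded layout, these slices are non-contiguous views, so the
    # kernel is also exercised on strided inputs.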
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
torch.testing.assert_close(
out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query),
)
if use_key:
torch.testing.assert_close(
out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key),
)
else:
        assert ref_key is None and out_key is None, "expected returned key to be None"


@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
ROPE_THETAS = [10000, 1000000]
ROPE_PARAMETERS = (
{"rope_type": "default"},
{"rope_type": "linear", "factor": (1,)},
{"rope_type": "dynamic", "factor": 1},
)
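    # The variants above cover the default RoPE plus the linear- and
    # dynamic-scaling code paths.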
settings = (
HEAD_SIZES,
ROTARY_DIMS,
MAX_POSITIONS,
ROPE_THETAS,
IS_NEOX_STYLE,
ROPE_PARAMETERS,
DTYPES,
)
rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings):
(
head_size,
rotary_dim,
max_position,
rope_theta,
is_neox_style,
rope_parameters,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,
dtype,
)
# different settings cannot share the same rope module
assert id(rope) not in rope_setting_id_map.values()
assert all(x.dtype == dtype for x in rope.buffers())
assert all(x.dtype == dtype for x in rope.parameters())
rope_setting_id_map[str(setting)] = id(rope)
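
    # Request every setting a second time; get_rope should serve each one from
    # the cache populated in the first pass.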
for setting in product(*settings):
(
head_size,
rotary_dim,
max_position,
rope_theta,
is_neox_style,
rope_parameters,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,
dtype,
)
        # check that the cache takes effect
assert id(rope) == rope_setting_id_map[str(setting)]
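
# Illustrative sketch (not part of the test suite) of the caching behavior
# verified above: calling get_rope twice with identical arguments is expected
# to return the very same module instance, e.g.
#
#   params = {"rope_type": "default", "rope_theta": 10000}
#   rope_a = get_rope(64, 64, 1024, True, params)
#   rope_b = get_rope(64, 64, 1024, True, params)
#   assert rope_a is rope_b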