mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 05:34:55 +08:00
261 lines
8.9 KiB
Python
261 lines
8.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Tests for the shuffle_rows function
|
|
|
|
Run `pytest tests/kernels/test_shuffle_rows.py`.
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm._custom_ops import shuffle_rows
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024])
|
|
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096])
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, dtype: torch.dtype):
|
|
"""Test basic functionality of shuffle_rows with various tensor sizes and
|
|
dtypes."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
# Create input tensor
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
|
|
# Create a simple permutation map (identity mapping)
|
|
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# With identity mapping, output should be identical to input
|
|
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
|
|
|
# Check output shape
|
|
assert output.shape == (num_tokens, hidden_size)
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [16, 64, 128])
|
|
@pytest.mark.parametrize("hidden_size", [128, 512, 1024])
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
|
def test_shuffle_rows_permutation(
|
|
num_tokens: int, hidden_size: int, dtype: torch.dtype
|
|
):
|
|
"""Test shuffle_rows with actual permutation."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
# Create input tensor
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
|
|
# Create a reverse permutation map
|
|
dst2src_map = torch.arange(num_tokens - 1, -1, -1, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Check that the output is the reverse of the input
|
|
expected_output = torch.flip(input_tensor, dims=[0])
|
|
torch.testing.assert_close(output, expected_output, atol=1e-6, rtol=1e-5)
|
|
|
|
# Check output shape and properties
|
|
assert output.shape == (num_tokens, hidden_size)
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [32, 64])
|
|
@pytest.mark.parametrize("hidden_size", [256, 512])
|
|
def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int):
|
|
"""Test shuffle_rows with expansion (more output tokens than input
|
|
tokens)."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
dtype = torch.float16
|
|
|
|
# Create input tensor
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
|
|
# Create a mapping that duplicates some tokens (expansion)
|
|
expanded_size = num_tokens * 2
|
|
dst2src_map = torch.randint(
|
|
0, num_tokens, (expanded_size,), device="cuda", dtype=torch.int32
|
|
)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Check output shape
|
|
assert output.shape == (expanded_size, hidden_size)
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
|
|
# Verify that each output row matches the corresponding input row
|
|
for i in range(expanded_size):
|
|
src_idx = dst2src_map[i].item()
|
|
torch.testing.assert_close(
|
|
output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [16, 64])
|
|
@pytest.mark.parametrize("hidden_size", [128, 512])
|
|
def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int):
|
|
"""Test shuffle_rows with random permutation."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
dtype = torch.float16
|
|
|
|
# Set seed for reproducibility
|
|
torch.manual_seed(42)
|
|
|
|
# Create input tensor
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
|
|
# Create a random permutation map
|
|
dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Check output shape and properties
|
|
assert output.shape == (num_tokens, hidden_size)
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
|
|
# Verify that each output row matches the corresponding input row
|
|
for i in range(num_tokens):
|
|
src_idx = dst2src_map[i].item()
|
|
torch.testing.assert_close(
|
|
output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5
|
|
)
|
|
|
|
|
|
def test_shuffle_rows_edge_cases():
|
|
"""Test shuffle_rows with edge cases."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
dtype = torch.float16
|
|
|
|
# Test with single token
|
|
input_tensor = torch.randn(1, 128, device="cuda", dtype=dtype)
|
|
dst2src_map = torch.tensor([0], device="cuda", dtype=torch.int32)
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
|
|
|
# Test with single feature dimension
|
|
input_tensor = torch.randn(16, 1, device="cuda", dtype=dtype)
|
|
dst2src_map = torch.arange(16, device="cuda", dtype=torch.int32)
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
|
|
|
|
|
def test_shuffle_rows_moe_like_scenario():
|
|
"""Test shuffle_rows in a scenario similar to MoE usage."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
dtype = torch.float16
|
|
batch_size = 32
|
|
hidden_size = 1024
|
|
topk = 2
|
|
|
|
# Simulate input tokens
|
|
input_tensor = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
|
|
|
|
# Simulate expert assignment (each token goes to topk experts)
|
|
# This creates a mapping where tokens are duplicated for multiple experts
|
|
total_tokens = batch_size * topk
|
|
dst2src_map = torch.zeros(total_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Fill the mapping to simulate MoE token distribution
|
|
for i in range(batch_size):
|
|
for k in range(topk):
|
|
dst2src_map[i * topk + k] = i
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Check output shape
|
|
assert output.shape == (total_tokens, hidden_size)
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
|
|
# Verify that tokens are correctly duplicated
|
|
for i in range(batch_size):
|
|
for k in range(topk):
|
|
output_idx = i * topk + k
|
|
torch.testing.assert_close(
|
|
output[output_idx], input_tensor[i], atol=1e-6, rtol=1e-5
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_shuffle_rows_dtype_consistency(dtype: torch.dtype):
|
|
"""Test that shuffle_rows preserves dtype correctly."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
num_tokens = 64
|
|
hidden_size = 512
|
|
|
|
# Create input tensor with specific dtype
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Verify dtype is preserved
|
|
assert output.dtype == dtype
|
|
assert output.device == input_tensor.device
|
|
torch.testing.assert_close(output, input_tensor, atol=1e-6, rtol=1e-5)
|
|
|
|
|
|
def test_shuffle_rows_device_consistency():
|
|
"""Test that shuffle_rows maintains device consistency."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
num_tokens = 32
|
|
hidden_size = 256
|
|
dtype = torch.float16
|
|
|
|
# Create input tensor on CUDA
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Verify device is maintained
|
|
assert output.device == input_tensor.device
|
|
assert output.device.type == "cuda"
|
|
|
|
|
|
def test_shuffle_rows_contiguous_output():
|
|
"""Test that shuffle_rows produces contiguous output."""
|
|
if not current_platform.is_cuda():
|
|
pytest.skip("shuffle_rows requires CUDA")
|
|
|
|
num_tokens = 64
|
|
hidden_size = 512
|
|
dtype = torch.float16
|
|
|
|
# Create input tensor
|
|
input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
|
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
|
|
|
# Test shuffle_rows
|
|
output = shuffle_rows(input_tensor, dst2src_map)
|
|
|
|
# Verify output is contiguous
|
|
assert output.is_contiguous()
|