mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 19:14:32 +08:00
[Bugfix] Fix CUDA arch flags for MoE permute (#21426)
Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
parent
13abd0eaf9
commit
2ded067fd2
@ -635,7 +635,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
@ -842,8 +842,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/moe/moe_permute_unpermute_op.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_PERMUTE_SRC}"
|
||||
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
|
||||
SRCS "${MOE_PERMUTE_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
|
||||
endif()
|
||||
|
||||
294
tests/kernels/test_shuffle_rows.py
Normal file
294
tests/kernels/test_shuffle_rows.py
Normal file
@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the shuffle_rows function
|
||||
|
||||
Run `pytest tests/kernels/test_shuffle_rows.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import shuffle_rows
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_shuffle_rows_basic(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype):
|
||||
"""Test basic functionality of shuffle_rows with various tensor sizes and
|
||||
dtypes."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
|
||||
# Create a simple permutation map (identity mapping)
|
||||
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# With identity mapping, output should be identical to input
|
||||
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == (num_tokens, hidden_size)
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [16, 64, 128])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 512, 1024])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype):
|
||||
"""Test shuffle_rows with actual permutation."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
|
||||
# Create a reverse permutation map
|
||||
dst2src_map = torch.arange(num_tokens - 1,
|
||||
-1,
|
||||
-1,
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Check that the output is the reverse of the input
|
||||
expected_output = torch.flip(input_tensor, dims=[0])
|
||||
torch.testing.assert_close(output, expected_output, atol=1e-6, rtol=1e-5)
|
||||
|
||||
# Check output shape and properties
|
||||
assert output.shape == (num_tokens, hidden_size)
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [256, 512])
|
||||
def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int):
|
||||
"""Test shuffle_rows with expansion (more output tokens than input
|
||||
tokens)."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
|
||||
# Create a mapping that duplicates some tokens (expansion)
|
||||
expanded_size = num_tokens * 2
|
||||
dst2src_map = torch.randint(0,
|
||||
num_tokens, (expanded_size, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == (expanded_size, hidden_size)
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
|
||||
# Verify that each output row matches the corresponding input row
|
||||
for i in range(expanded_size):
|
||||
src_idx = dst2src_map[i].item()
|
||||
torch.testing.assert_close(output[i],
|
||||
input_tensor[src_idx],
|
||||
atol=1e-6,
|
||||
rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [16, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 512])
|
||||
def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int):
|
||||
"""Test shuffle_rows with random permutation."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
|
||||
# Create a random permutation map
|
||||
dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Check output shape and properties
|
||||
assert output.shape == (num_tokens, hidden_size)
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
|
||||
# Verify that each output row matches the corresponding input row
|
||||
for i in range(num_tokens):
|
||||
src_idx = dst2src_map[i].item()
|
||||
torch.testing.assert_close(output[i],
|
||||
input_tensor[src_idx],
|
||||
atol=1e-6,
|
||||
rtol=1e-5)
|
||||
|
||||
|
||||
def test_shuffle_rows_edge_cases():
|
||||
"""Test shuffle_rows with edge cases."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
# Test with single token
|
||||
input_tensor = torch.randn(1, 128, device="cuda", dtype=dtype)
|
||||
dst2src_map = torch.tensor([0], device="cuda", dtype=torch.int32)
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
||||
|
||||
# Test with single feature dimension
|
||||
input_tensor = torch.randn(16, 1, device="cuda", dtype=dtype)
|
||||
dst2src_map = torch.arange(16, device="cuda", dtype=torch.int32)
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
torch.testing.assert_close(output, input_tensor, atol=0, rtol=0)
|
||||
|
||||
|
||||
def test_shuffle_rows_moe_like_scenario():
|
||||
"""Test shuffle_rows in a scenario similar to MoE usage."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
dtype = torch.float16
|
||||
batch_size = 32
|
||||
hidden_size = 1024
|
||||
topk = 2
|
||||
|
||||
# Simulate input tokens
|
||||
input_tensor = torch.randn(batch_size,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
|
||||
# Simulate expert assignment (each token goes to topk experts)
|
||||
# This creates a mapping where tokens are duplicated for multiple experts
|
||||
total_tokens = batch_size * topk
|
||||
dst2src_map = torch.zeros(total_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Fill the mapping to simulate MoE token distribution
|
||||
for i in range(batch_size):
|
||||
for k in range(topk):
|
||||
dst2src_map[i * topk + k] = i
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == (total_tokens, hidden_size)
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
|
||||
# Verify that tokens are correctly duplicated
|
||||
for i in range(batch_size):
|
||||
for k in range(topk):
|
||||
output_idx = i * topk + k
|
||||
torch.testing.assert_close(output[output_idx],
|
||||
input_tensor[i],
|
||||
atol=1e-6,
|
||||
rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_shuffle_rows_dtype_consistency(dtype: torch.dtype):
|
||||
"""Test that shuffle_rows preserves dtype correctly."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
num_tokens = 64
|
||||
hidden_size = 512
|
||||
|
||||
# Create input tensor with specific dtype
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Verify dtype is preserved
|
||||
assert output.dtype == dtype
|
||||
assert output.device == input_tensor.device
|
||||
torch.testing.assert_close(output, input_tensor, atol=1e-6, rtol=1e-5)
|
||||
|
||||
|
||||
def test_shuffle_rows_device_consistency():
|
||||
"""Test that shuffle_rows maintains device consistency."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
num_tokens = 32
|
||||
hidden_size = 256
|
||||
dtype = torch.float16
|
||||
|
||||
# Create input tensor on CUDA
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Verify device is maintained
|
||||
assert output.device == input_tensor.device
|
||||
assert output.device.type == "cuda"
|
||||
|
||||
|
||||
def test_shuffle_rows_contiguous_output():
|
||||
"""Test that shuffle_rows produces contiguous output."""
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("shuffle_rows requires CUDA")
|
||||
|
||||
num_tokens = 64
|
||||
hidden_size = 512
|
||||
dtype = torch.float16
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
|
||||
|
||||
# Test shuffle_rows
|
||||
output = shuffle_rows(input_tensor, dst2src_map)
|
||||
|
||||
# Verify output is contiguous
|
||||
assert output.is_contiguous()
|
||||
Loading…
x
Reference in New Issue
Block a user