# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random

import pytest
import torch

from vllm import _custom_ops as ops
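
# This test exercises the moe_lora_align_block_size custom op, which is
# expected to group token/expert assignments into block_size-aligned chunks
# per LoRA adapter for the fused MoE + LoRA path.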


def round_up(x, base):
    # Round x up to the nearest multiple of base.
    return ((x + base - 1) // base) * base


def CEILDIV(x, y):
    # Ceiling division: the smallest integer >= x / y.
    return (x + y - 1) // y
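

# sample_data builds random inputs: each token draws topk_num distinct
# expert ids (the shuffle guarantees no duplicates within a row) and a
# random LoRA id in [0, max_loras).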
def sample_data(num_experts, max_loras, num_tokens, topk_num):
    topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
    token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32)

    for i in range(num_tokens):
        pool = list(range(num_experts))
        random.shuffle(pool)
        for j in range(topk_num):
            topk_ids[i, j] = pool[j]
        token_lora_mapping[i] = random.randint(0, max_loras - 1)

    return topk_ids.to("cuda"), token_lora_mapping.to("cuda")
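

# Hypothetical pure-Python sketch of the layout the assertions below rely on;
# it is not called by the test and only documents the assumed contract: per
# LoRA slot, indices into the flattened topk_ids are bucketed by expert, and
# each bucket is padded to a multiple of block_size with the sentinel value
# topk_ids.numel(). The real kernel may order entries differently.
def _reference_align_sketch(
    topk_ids, token_lora_mapping, num_experts, block_size, max_loras
):
    sentinel = topk_ids.numel()
    num_tokens, topk_num = topk_ids.shape
    aligned = []
    for lora in range(max_loras):
        flat = []
        for expert in range(num_experts):
            # Flattened positions of the (token, k) pairs routed to this
            # (lora, expert) combination.
            bucket = [
                t * topk_num + k
                for t in range(num_tokens)
                for k in range(topk_num)
                if token_lora_mapping[t] == lora and topk_ids[t, k] == expert
            ]
            pad = (-len(bucket)) % block_size
            flat.extend(bucket + [sentinel] * pad)
        aligned.append(flat)
    return aligned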
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
|
|
@pytest.mark.parametrize("topk_num", [6])
|
|
@pytest.mark.parametrize("num_experts", [64, 128])
|
|
@pytest.mark.parametrize("max_loras", [2, 32])
|
|
@pytest.mark.parametrize("block_size", [16])
|
|
def test_moe_lora_align_block_size(
|
|
num_tokens, topk_num, num_experts, max_loras, block_size
|
|
):
|
|
    # sample data
    random.seed(1)
    topk_ids, token_lora_mapping = sample_data(
        num_experts, max_loras, num_tokens, topk_num
    )
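
    # Worst case per LoRA slot: every expert bucket can waste up to
    # block_size - 1 slots on padding.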
    # compute paddings
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
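
    # topk_ids.numel() and num_experts are one past the largest valid token
    # index and expert id, so they serve as padding sentinels in
    # sorted_token_ids and expert_ids respectively.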
    # init output tensors
    sorted_token_ids = torch.full(
        (max_loras * max_num_tokens_padded,),
        topk_ids.numel(),
        dtype=torch.int32,
        device="cuda",
    )
    expert_ids = torch.full(
        (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
    )
    num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
    # The slots beyond max_loras presumably cover the no-LoRA case.
    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
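
    # The op writes its results in place into sorted_token_ids, expert_ids,
    # and num_tokens_post_pad.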
    # call kernel
    ops.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
        max_num_tokens_padded,
        max_num_m_blocks,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_pad,
        adapter_enabled,
        lora_ids,
    )
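
    # Invariant under test: within every block, each non-sentinel entry must
    # index a flattened topk_ids value equal to that block's expert id.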
    # verify values
    expert_ids = expert_ids.view(max_loras, -1)
    sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)

    for lora_idx in range(max_loras):
        for block_idx in range(sorted_token_ids.size(1)):
            block = sorted_token_ids[lora_idx][block_idx]
            indices = block[block != topk_ids.numel()]
            if indices.numel() > 0:
                expert_id = expert_ids[lora_idx][block_idx]
                assert torch.all(topk_ids.view(-1)[indices] == expert_id)


if __name__ == "__main__":
    pytest.main([__file__])