[Misc] LoRA - Refactor Punica ops tests (#12970)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-02-11 12:56:03 +05:30 committed by GitHub
parent c320ca8edd
commit 78a141d768
4 changed files with 686 additions and 725 deletions


@@ -0,0 +1,652 @@
# SPDX-License-Identifier: Apache-2.0
from threading import Lock
from typing import List
import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
nslices: int, inputs_tensor: torch.Tensor,
lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
num_tokens: int, scaling: float):
"""
Wrapper around sgmv_shrink that handles any nslices.
"""
for index in range(nslices):
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
scaling,
)
def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
inputs_tensor: torch.Tensor,
lora_weights_lst: List[torch.Tensor],
out_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
prompt_lora_mapping: torch.Tensor, batches: int,
max_seq_length: int, num_tokens: int,
add_inputs: bool) -> None:
"""
Wrapper around sgmv_expand that handles any nslices.
"""
if nslices == 1:
        # Use torch's sgmv_expand op directly as the reference.
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
out_tensor,
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
add_inputs=add_inputs,
)
else:
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
out_tensor,
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
slice_offset,
hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
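# _dict_lock serializes clearing of the shared Triton pointer caches
# (_LORA_A_PTR_DICT / _LORA_B_PTR_DICT) with the kernel launches in the check_*
# helpers below, so a cache reset from one test cannot race with another
# in-flight invocation (presumably relevant when tests run in parallel).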
_dict_lock = Lock()
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
"""
    Compare outputs of the vllm.sgmv_shrink kernel against a reference
    implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
"shrink",
device,
)
max_seq_length, token_nums = data.meta()
    # Clear the pointer cache so stale cached pointers do not carry over.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
scaling,
)
sgmv_shrink_for_nslices(
nslices,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
scaling,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
"""
    Compare outputs of the vllm.sgmv_expand kernel against a reference
    implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
"expand",
device,
)
max_seq_length, token_nums = data.meta()
with _dict_lock:
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
offset_start=0,
add_inputs=add_inputs,
)
sgmv_expand_for_nslices(nslices,
hidden_size,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
add_inputs=add_inputs)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
scaling: float):
"""
Compare vllm.bgmv_shrink against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"shrink",
device,
)
torch.ops.vllm.bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
scaling,
)
bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
scaling,
)
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
add_inputs: bool):
"""
Compare vllm.bgmv_expand against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"expand",
device,
)
torch.ops.vllm.bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, add_inputs: bool):
"""
Compare vllm.bgmv_expand_slice against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
torch.ops.vllm.bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.our_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.ref_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Tests
# We test the punica kernels mainly along 2 verticals:
# 1. Variations in hidden_dim size.
# 2. Variations in all other parameters (batch_size, max_rank, num_loras, etc.).
# HIDDEN_SIZES below collects the hidden sizes of the LoRA models currently
# supported by vLLM. The tests check whether the corresponding Triton kernels
# run correctly when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64].
HIDDEN_SIZES = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
# The tensor-parallel (TP) sizes; hidden sizes are divided by these.
divisibility = [1, 2, 8, 16, 64]
all_hidden_size = []
for div in divisibility:
for hidden_size in HIDDEN_SIZES:
all_hidden_size.append(hidden_size // div)
HIDDEN_SIZES = list(set(all_hidden_size))
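# For illustration, the expanded set mixes the original sizes (div == 1) with
# per-shard widths such as 4096 // 16 == 256 and 128256 // 64 == 2004,
# approximating the hidden-size shard each TP rank would see.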
# Test params that focus on hidden_size variation.
hs_test_params = {
"hidden_sizes": HIDDEN_SIZES,
"batches": [4],
"num_loras": [4],
"max_ranks": [32],
}
# General test params that test for variations in all dimensions
# except hidden_size.
test_params = {
"hidden_sizes": [2049],
"batches": [1, 4, 16, 32],
"num_loras": [1, 8, 32, 128],
"max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
SEED = [0]
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str,
seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
rank: int, hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str, seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
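# Example invocation (the file path is an assumption; adjust to wherever this
# test module lives, e.g. tests/lora/test_punica_ops.py):
#   pytest -q tests/lora/test_punica_ops.py -k "test_punica_sgmv and shrink"
# pytest's -k expression matches the parametrized test ids, so "shrink" selects
# only the op_type == "shrink" cases.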


@@ -1,401 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
from threading import Lock
import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
HIDDEN_SIZES = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
# The size of TP
divisibility = [1, 2, 8, 16, 64]
all_hidden_size = []
for div in divisibility:
for hidden_size in HIDDEN_SIZES:
all_hidden_size.append(hidden_size // div)
HIDDEN_SIZES = list(set(all_hidden_size))
BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
DEVICES = [f"cuda:{0}"]
_dict_lock = Lock()
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128
(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
for index in range(nslices):
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
ref_out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
offset_start=0,
add_inputs=True,
)
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
else:
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
torch.ops.vllm.bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
scaling,
)
else:
torch.ops.vllm.bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
add_inputs=True,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
torch.ops.vllm.bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)


@@ -1,317 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batch sizes, numbers of LoRAs,
and maximum ranks.
"""
from threading import Lock
import pytest
import torch
# Enable custom op registration
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
HIDDEN_SIZES = [2049]
BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 8, 32, 128]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
SCALES = [0.5]
SEED = [0]
DEVICES = [f"cuda:{0}"]
_dict_lock = Lock()
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 128
(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
for index in range(nslices):
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
ref_out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
offset_start=0,
add_inputs=True,
)
slice_offset = 0
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
else:
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
torch.ops.vllm.bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
scaling,
)
else:
torch.ops.vllm.bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
add_inputs=True,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
seq_length = 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
torch.ops.vllm.bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
@@ -106,6 +107,31 @@ def assert_close(a, b):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@dataclass
class PunicaTensors:
inputs_tensor: torch.Tensor
lora_weights: Union[torch.Tensor, List[torch.Tensor]]
our_out_tensor: torch.Tensor
ref_out_tensor: torch.Tensor
b_seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor
seq_len_tensor: torch.Tensor
token_lora_mapping: torch.Tensor
def meta(self) -> Tuple[int, int]:
"""
Infer max_seq_length and token_nums from the tensors
and return them.
"""
max_seq_length = self.seq_len_tensor.max()
token_nums = self.seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
return max_seq_length, token_nums
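# Usage sketch (mirroring the refactored tests above): the generate_data*
# helpers below return a PunicaTensors instance, and callers pull the kernel
# launch metadata out of it via meta(), e.g.
#   data = generate_data_for_nslices(batches, hidden_size, num_loras, rank,
#                                    seq_length, nslices, dtype, "shrink", device)
#   max_seq_length, token_nums = data.meta()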
def generate_data(
batches,
hidden_size,
@@ -115,7 +141,7 @@ def generate_data(
dtype,
op_type,
device,
):
) -> PunicaTensors:
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
@@ -164,7 +190,8 @@ def generate_data(
indices[current_offset:current_offset +
seq_len_tensor[b_id]].copy_(lora_index)
current_offset += seq_len_tensor[b_id].item()
return (
return PunicaTensors(
inputs_tensor,
lora_weights,
our_out_tensor,
@@ -185,7 +212,7 @@ def generate_data_for_expand_nslices(
dtype,
nslices,
device,
):
) -> PunicaTensors:
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
@@ -222,7 +249,7 @@ def generate_data_for_expand_nslices(
current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device)
return (
return PunicaTensors(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
@@ -244,7 +271,7 @@ def generate_data_for_nslices(
dtype,
op_type,
device,
):
) -> PunicaTensors:
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
@@ -302,7 +329,7 @@ def generate_data_for_nslices(
current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device)
return (
return PunicaTensors(
inputs_tensor,
lora_weights_lst,
our_out_tensor,