vllm/tests/lora/test_worker.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import tempfile
from unittest.mock import patch

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.load import LoadConfig
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker

NUM_LORAS = 16


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
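    """Check that a GPU worker can activate LoRA adapters and list them."""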

    def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
        lora_mapping = LoRAMapping([], [])
        worker.model_runner.lora_manager.set_active_adapters(
            lora_requests, lora_mapping)

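    # Single-GPU setup; dummy weight loading keeps the test lightweight.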
    vllm_config = VllmConfig(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            seed=0,
            dtype="float16",
            enforce_eager=True,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,
            swap_space=0,
            cache_dtype="auto",
        ),
        lora_config=LoRAConfig(max_lora_rank=8,
                               max_cpu_loras=NUM_LORAS,
                               max_loras=NUM_LORAS),
    )
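    # A file-based init method avoids network setup for this single worker.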
    worker = Worker(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_device()
    worker.load_model()

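    # With nothing activated, no adapters should be registered.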
    set_active_loras(worker, [])
    assert worker.list_loras() == set()

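    # NUM_LORAS distinct requests, all backed by the same LoRA checkpoint.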
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files)
        for i in range(NUM_LORAS)
    ]

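    # Activating every request should register all of their IDs.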
    set_active_loras(worker, lora_requests)
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

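    # Each iteration activates a random subset; every adapter in that subset
    # must still be reported by list_loras().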
    for i in range(NUM_LORAS):
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, NUM_LORAS))
        random.shuffle(iter_lora_requests)
        iter_lora_requests = iter_lora_requests[:-random.randint(0, NUM_LORAS)]
        set_active_loras(worker, iter_lora_requests)
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})