# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random
import tempfile
from unittest.mock import patch

from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker

NUM_LORAS = 16


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    # Helper: activate the given LoRA adapters on the worker with an empty mapping.
    def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
        lora_mapping = LoRAMapping([], [])
        worker.model_runner.lora_manager.set_active_adapters(
            lora_requests, lora_mapping
        )

    # Minimal single-GPU config with dummy weights and LoRA support for up to
    # NUM_LORAS adapters.
    vllm_config = VllmConfig(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            seed=0,
            dtype="float16",
            enforce_eager=True,
        ),
        load_config=LoadConfig(
            download_dir=None,
            load_format="dummy",
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,
            swap_space=0,
            cache_dtype="auto",
        ),
        lora_config=LoRAConfig(
            max_lora_rank=8, max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS
        ),
    )
    worker = Worker(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_device()
    worker.load_model()

    # With no adapters activated, the worker should report none.
    set_active_loras(worker, [])
    assert worker.list_loras() == set()

    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS)
    ]

    # Activating all adapters should register every lora_int_id.
    set_active_loras(worker, lora_requests)
    assert worker.list_loras() == {
        lora_request.lora_int_id for lora_request in lora_requests
    }

    # Activate random subsets and verify each subset stays registered.
    for i in range(NUM_LORAS):
        random.seed(i)
        iter_lora_requests = random.choices(
            lora_requests, k=random.randint(1, NUM_LORAS)
        )
        random.shuffle(iter_lora_requests)
        iter_lora_requests = iter_lora_requests[: -random.randint(0, NUM_LORAS)]
        set_active_loras(worker, lora_requests)
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id for lora_request in iter_lora_requests}
        )