diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 2b421bfd9eb8..70b058b201d6 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -7,6 +7,7 @@ from typing import List import pytest from huggingface_hub import snapshot_download +import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest @@ -144,10 +145,14 @@ async def test_add_lora(): await requests_processing_time(llm, dummy_run_requests) # Run with warmup - for lr in warmup_run_requests: - await llm.add_lora(lr) - # Wait for the add_lora function to complete on the server side. - await asyncio.sleep(30) + add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] + add_lora_results = await asyncio.gather(*add_lora_tasks) + if env.VLLM_USE_V1: + # Test that all all_lora calls are successful. + assert all(add_lora_results) + else: + # No way to check V0 engine results as the calls just return None. + pass time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py new file mode 100644 index 000000000000..1309848868b4 --- /dev/null +++ b/tests/lora/test_lora_functions.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Script to test add_lora, remove_lora, pin_lora, list_loras functions. +""" + +import os +from typing import List + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.llm import LLM +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test" +LORA_RANK = 8 + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +def make_lora_request(lora_id: int): + return LoRARequest(lora_name=f"{lora_id}", + lora_int_id=lora_id, + lora_path=LORA_MODULE_PATH) + + +def test_lora_functions_sync(): + + max_loras = 4 + # Create engine in eager-mode. Due to high max_loras, the CI can + # OOM during cuda-graph capture. + engine_args = EngineArgs(model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True) + + llm = LLM.get_engine_class().from_engine_args(engine_args) + + def run_check(fn, args, expected: List): + fn(args) + assert set(llm.list_loras()) == set(expected) + + run_check(llm.add_lora, make_lora_request(1), [1]) + run_check(llm.add_lora, make_lora_request(2), [1, 2]) + + # Pin LoRA 1 and test that it is never removed on subsequent adds. + run_check(llm.pin_lora, 1, [1, 2]) + run_check(llm.add_lora, make_lora_request(3), [1, 2, 3]) + run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4]) + run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4]) + run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4]) + run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7]) + run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7]) + run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7]) + run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10]) + + # Remove LoRA 1 and continue adding. + run_check(llm.remove_lora, 1, [8, 9, 10]) + run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11]) + run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) + run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) + + # Remove all LoRAs + run_check(llm.remove_lora, 13, [12, 10, 11]) + run_check(llm.remove_lora, 12, [10, 11]) + run_check(llm.remove_lora, 11, [10]) + run_check(llm.remove_lora, 10, []) + + +@pytest.mark.asyncio +async def test_lora_functions_async(): + + if os.getenv("VLLM_USE_V1") == "0": + pytest.skip( + reason= + "V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions") + + # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1` + # environment variable. reload vllm.enging.async_llm_engine as + # vllm.engine.async_llm_engine.AsyncLLMEgnine changes depending on the + # env var. + import importlib + + import vllm.engine.async_llm_engine + importlib.reload(vllm.engine.async_llm_engine) + from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) + + max_loras = 4 + engine_args = AsyncEngineArgs(model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True) + + async def run_check(fn, args, expected: List): + await fn(args) + assert set(await llm.list_loras()) == set(expected) + + async with build_async_engine_client_from_engine_args(engine_args) as llm: + await run_check(llm.add_lora, make_lora_request(1), [1]) + await run_check(llm.add_lora, make_lora_request(2), [1, 2]) + + # Pin LoRA 1 and test that it is never removed on subsequent adds. + await run_check(llm.pin_lora, 1, [1, 2]) + await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3]) + await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4]) + await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4]) + await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4]) + await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7]) + await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7]) + await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7]) + await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10]) + + # Remove LoRA 1 and continue adding. + await run_check(llm.remove_lora, 1, [8, 9, 10]) + await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11]) + await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) + await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) + + # Remove all LoRAs + await run_check(llm.remove_lora, 13, [12, 10, 11]) + await run_check(llm.remove_lora, 12, [10, 11]) + await run_check(llm.remove_lora, 11, [10]) + await run_check(llm.remove_lora, 10, []) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 36a02628f405..0c04e14cec2f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import AsyncGenerator, List, Mapping, Optional, Type, Union +from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union import numpy as np @@ -392,9 +392,21 @@ class AsyncLLM(EngineClient): async def wake_up(self) -> None: await self.engine_core.wake_up_async() - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" - await self.engine_core.add_lora_async(lora_request) + return await self.engine_core.add_lora_async(lora_request) + + async def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return await self.engine_core.remove_lora_async(lora_id) + + async def list_loras(self) -> Set[int]: + """List all registered adapters.""" + return await self.engine_core.list_loras_async() + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return await self.engine_core.pin_lora_async(lora_id) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 85c97293af8b..041896f1c7cc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,7 +7,7 @@ import time from concurrent.futures import Future from inspect import isclass, signature from multiprocessing.connection import Connection -from typing import Any, List, Optional, Tuple, Type +from typing import Any, List, Optional, Set, Tuple, Type import msgspec import psutil @@ -222,8 +222,17 @@ class EngineCore: def execute_dummy_batch(self): self.model_executor.collective_rpc("execute_dummy_batch") - def add_lora(self, lora_request: LoRARequest) -> None: - self.model_executor.add_lora(lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) class EngineCoreProc(EngineCore): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 5ffaf63e6cec..9f36e11d12d7 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Set, Type, Union import zmq import zmq.asyncio @@ -97,7 +97,16 @@ class EngineCoreClient(ABC): def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError - def add_lora(self, lora_request: LoRARequest) -> None: + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> Set[int]: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError async def get_output_async(self) -> EngineCoreOutputs: @@ -121,7 +130,16 @@ class EngineCoreClient(ABC): async def abort_requests_async(self, request_ids: List[str]) -> None: raise NotImplementedError - async def add_lora_async(self, lora_request: LoRARequest) -> None: + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + async def remove_lora_async(self, lora_id: int) -> bool: + raise NotImplementedError + + async def list_loras_async(self) -> Set[int]: + raise NotImplementedError + + async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError @@ -166,8 +184,17 @@ class InprocClient(EngineCoreClient): def execute_dummy_batch(self) -> None: self.engine_core.execute_dummy_batch() - def add_lora(self, lora_request: LoRARequest) -> None: - self.engine_core.add_lora(lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.engine_core.pin_lora(lora_id) @dataclass @@ -356,8 +383,17 @@ class SyncMPClient(MPClient): def reset_prefix_cache(self) -> None: self._call_utility("reset_prefix_cache") - def add_lora(self, lora_request: LoRARequest) -> None: - self._call_utility("add_lora", lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self._call_utility("add_lora", lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self._call_utility("remove_lora", lora_id) + + def list_loras(self) -> Set[int]: + return self._call_utility("list_loras") + + def pin_lora(self, lora_id: int) -> bool: + return self._call_utility("pin_lora", lora_id) def sleep(self, level: int = 1) -> None: self._call_utility("sleep", level) @@ -454,5 +490,14 @@ class AsyncMPClient(MPClient): async def execute_dummy_batch_async(self) -> None: await self._call_utility_async("execute_dummy_batch") - async def add_lora_async(self, lora_request: LoRARequest) -> None: - await self._call_utility_async("add_lora", lora_request) + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + return await self._call_utility_async("add_lora", lora_request) + + async def remove_lora_async(self, lora_id: int) -> bool: + return await self._call_utility_async("remove_lora", lora_id) + + async def list_loras_async(self) -> Set[int]: + return await self._call_utility_async("list_loras") + + async def pin_lora_async(self, lora_id: int) -> bool: + return await self._call_utility_async("pin_lora", lora_id) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 64fd8719c82e..ccf52250c1d6 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Mapping, Optional, Type, Union +from typing import Dict, List, Mapping, Optional, Set, Type, Union from typing_extensions import TypeVar @@ -254,3 +254,19 @@ class LLMEngine: f"found type: {type(tokenizer_group)}") return tokenizer_group + + def add_lora(self, lora_request: LoRARequest) -> bool: + """Load a new LoRA adapter into the engine for future requests.""" + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + """List all registered adapters.""" + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return self.engine_core.pin_lora(lora_id) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a14a7082df4b..f681925f557e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Set import torch import torch.distributed @@ -240,6 +240,15 @@ class Worker(WorkerBase): def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + def check_health(self) -> None: # worker will always be healthy as long as it's running. return diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 053897da0aa7..731e758e6e74 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -131,4 +131,19 @@ class LoRAModelRunnerMixin: def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) \ No newline at end of file + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() \ No newline at end of file