Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:15:01 +08:00)
[Core] LoRA V1 - Add add/pin/list/remove_lora functions (#13705)
commit 03f48b3db6
parent 4d251ad00e
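Previously the V1 engine only exposed add_lora (returning None), with no way to remove, list, or pin adapters at runtime. This commit threads four management calls through the whole V1 stack (AsyncLLM, LLMEngine, the EngineCore clients, EngineCore, Worker, and the LoRA model-runner mixin) and propagates their results back to the caller. Collected from the diffs below, the resulting surface is:

# New/changed V1 engine management surface (signatures collected from the
# diffs below; the same methods appear on LLMEngine, AsyncLLM, the EngineCore
# clients, EngineCore, Worker and LoRAModelRunnerMixin):
#
#   add_lora(lora_request: LoRARequest) -> bool   # previously returned None
#   remove_lora(lora_id: int) -> bool             # new
#   list_loras() -> Set[int]                      # new
#   pin_lora(lora_id: int) -> bool                # new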
@@ -7,6 +7,7 @@ from typing import List
 import pytest
 from huggingface_hub import snapshot_download
 
+import vllm.envs as env
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import TextPrompt
 from vllm.lora.request import LoRARequest
@@ -144,10 +145,14 @@ async def test_add_lora():
     await requests_processing_time(llm, dummy_run_requests)
 
     # Run with warmup
-    for lr in warmup_run_requests:
-        await llm.add_lora(lr)
-    # Wait for the add_lora function to complete on the server side.
-    await asyncio.sleep(30)
+    add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
+    add_lora_results = await asyncio.gather(*add_lora_tasks)
+    if env.VLLM_USE_V1:
+        # Test that all add_lora calls are successful.
+        assert all(add_lora_results)
+    else:
+        # No way to check V0 engine results as the calls just return None.
+        pass
     time_with_add_lora = await requests_processing_time(
         llm, warmup_run_requests)
 
tests/lora/test_lora_functions.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
"""
Script to test add_lora, remove_lora, pin_lora, list_loras functions.
"""

import os
from typing import List

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.lora.request import LoRARequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
LORA_RANK = 8


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


def make_lora_request(lora_id: int):
    return LoRARequest(lora_name=f"{lora_id}",
                       lora_int_id=lora_id,
                       lora_path=LORA_MODULE_PATH)


def test_lora_functions_sync():

    max_loras = 4
    # Create engine in eager-mode. Due to high max_loras, the CI can
    # OOM during cuda-graph capture.
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_lora=True,
                             max_loras=max_loras,
                             max_lora_rank=LORA_RANK,
                             max_model_len=128,
                             gpu_memory_utilization=0.8,
                             enforce_eager=True)

    llm = LLM.get_engine_class().from_engine_args(engine_args)

    def run_check(fn, args, expected: List):
        fn(args)
        assert set(llm.list_loras()) == set(expected)

    run_check(llm.add_lora, make_lora_request(1), [1])
    run_check(llm.add_lora, make_lora_request(2), [1, 2])

    # Pin LoRA 1 and test that it is never removed on subsequent adds.
    run_check(llm.pin_lora, 1, [1, 2])
    run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
    run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
    run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
    run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
    run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
    run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
    run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
    run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

    # Remove LoRA 1 and continue adding.
    run_check(llm.remove_lora, 1, [8, 9, 10])
    run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
    run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
    run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

    # Remove all LoRAs
    run_check(llm.remove_lora, 13, [12, 10, 11])
    run_check(llm.remove_lora, 12, [10, 11])
    run_check(llm.remove_lora, 11, [10])
    run_check(llm.remove_lora, 10, [])


@pytest.mark.asyncio
async def test_lora_functions_async():

    if os.getenv("VLLM_USE_V1") == "0":
        pytest.skip(
            reason=
            "V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions")

    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
    # environment variable. Reload vllm.engine.async_llm_engine as
    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
    # env var.
    import importlib

    import vllm.engine.async_llm_engine
    importlib.reload(vllm.engine.async_llm_engine)
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args)

    max_loras = 4
    engine_args = AsyncEngineArgs(model=MODEL_PATH,
                                  enable_lora=True,
                                  max_loras=max_loras,
                                  max_lora_rank=LORA_RANK,
                                  max_model_len=128,
                                  gpu_memory_utilization=0.8,
                                  enforce_eager=True)

    async def run_check(fn, args, expected: List):
        await fn(args)
        assert set(await llm.list_loras()) == set(expected)

    async with build_async_engine_client_from_engine_args(engine_args) as llm:
        await run_check(llm.add_lora, make_lora_request(1), [1])
        await run_check(llm.add_lora, make_lora_request(2), [1, 2])

        # Pin LoRA 1 and test that it is never removed on subsequent adds.
        await run_check(llm.pin_lora, 1, [1, 2])
        await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
        await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
        await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
        await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
        await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
        await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
        await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
        await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

        # Remove LoRA 1 and continue adding.
        await run_check(llm.remove_lora, 1, [8, 9, 10])
        await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
        await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
        await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

        # Remove all LoRAs
        await run_check(llm.remove_lora, 13, [12, 10, 11])
        await run_check(llm.remove_lora, 12, [10, 11])
        await run_check(llm.remove_lora, 11, [10])
        await run_check(llm.remove_lora, 10, [])
@@ -2,7 +2,7 @@
 
 import asyncio
 import os
-from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union
 
 import numpy as np
 
@@ -392,9 +392,21 @@ class AsyncLLM(EngineClient):
     async def wake_up(self) -> None:
         await self.engine_core.wake_up_async()
 
-    async def add_lora(self, lora_request: LoRARequest) -> None:
+    async def add_lora(self, lora_request: LoRARequest) -> bool:
         """Load a new LoRA adapter into the engine for future requests."""
-        await self.engine_core.add_lora_async(lora_request)
+        return await self.engine_core.add_lora_async(lora_request)
+
+    async def remove_lora(self, lora_id: int) -> bool:
+        """Remove an already loaded LoRA adapter."""
+        return await self.engine_core.remove_lora_async(lora_id)
+
+    async def list_loras(self) -> Set[int]:
+        """List all registered adapters."""
+        return await self.engine_core.list_loras_async()
+
+    async def pin_lora(self, lora_id: int) -> bool:
+        """Prevent an adapter from being evicted."""
+        return await self.engine_core.pin_lora_async(lora_id)
 
     @property
     def is_running(self) -> bool:
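The AsyncLLM methods above make adapter management awaitable from client code. A minimal sketch, assuming an already running AsyncLLM (or the client yielded by build_async_engine_client_from_engine_args, as in the new test) bound to `engine`; the adapter path is the one used by the tests, and the method names and return types come from the diff above:

# Hedged sketch: `engine` is assumed to be a LoRA-enabled V1 async engine client.
from vllm.lora.request import LoRARequest

async def manage_adapters(engine):
    req = LoRARequest(lora_name="1", lora_int_id=1,
                      lora_path="yard1/llama-2-7b-sql-lora-test")
    assert await engine.add_lora(req)       # V1 now reports success as a bool
    await engine.pin_lora(1)                # keep adapter 1 resident
    print(await engine.list_loras())        # e.g. {1}
    await engine.remove_lora(1)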
@@ -7,7 +7,7 @@ import time
 from concurrent.futures import Future
 from inspect import isclass, signature
 from multiprocessing.connection import Connection
-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, List, Optional, Set, Tuple, Type
 
 import msgspec
 import psutil
@@ -222,8 +222,17 @@ class EngineCore:
     def execute_dummy_batch(self):
         self.model_executor.collective_rpc("execute_dummy_batch")
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self.model_executor.add_lora(lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_executor.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_executor.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_executor.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
 
 
 class EngineCoreProc(EngineCore):
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import Future
 from dataclasses import dataclass
 from threading import Thread
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Set, Type, Union
 
 import zmq
 import zmq.asyncio
@@ -97,7 +97,16 @@ class EngineCoreClient(ABC):
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
+    def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    def list_loras(self) -> Set[int]:
+        raise NotImplementedError
+
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
 
     async def get_output_async(self) -> EngineCoreOutputs:
@@ -121,7 +130,16 @@
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
-    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
+
+    async def remove_lora_async(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    async def list_loras_async(self) -> Set[int]:
+        raise NotImplementedError
+
+    async def pin_lora_async(self, lora_id: int) -> bool:
+        raise NotImplementedError
 
 
@@ -166,8 +184,17 @@ class InprocClient(EngineCoreClient):
     def execute_dummy_batch(self) -> None:
         self.engine_core.execute_dummy_batch()
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self.engine_core.add_lora(lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.engine_core.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.engine_core.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.engine_core.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.engine_core.pin_lora(lora_id)
 
 
 @dataclass
@@ -356,8 +383,17 @@ class SyncMPClient(MPClient):
     def reset_prefix_cache(self) -> None:
         self._call_utility("reset_prefix_cache")
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self._call_utility("add_lora", lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self._call_utility("add_lora", lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self._call_utility("remove_lora", lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self._call_utility("list_loras")
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self._call_utility("pin_lora", lora_id)
 
     def sleep(self, level: int = 1) -> None:
         self._call_utility("sleep", level)
@@ -454,5 +490,14 @@ class AsyncMPClient(MPClient):
     async def execute_dummy_batch_async(self) -> None:
         await self._call_utility_async("execute_dummy_batch")
 
-    async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._call_utility_async("add_lora", lora_request)
+    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
+        return await self._call_utility_async("add_lora", lora_request)
+
+    async def remove_lora_async(self, lora_id: int) -> bool:
+        return await self._call_utility_async("remove_lora", lora_id)
+
+    async def list_loras_async(self) -> Set[int]:
+        return await self._call_utility_async("list_loras")
+
+    async def pin_lora_async(self, lora_id: int) -> bool:
+        return await self._call_utility_async("pin_lora", lora_id)
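Taken together, the client changes above just forward the new calls as utility RPCs to the engine-core process and return the result. The sketch below traces add_lora through the async multiprocess path using only names that appear in these diffs; the RPC mechanics inside _call_utility_async are not shown here and are an assumption:

# Call path implied by the diffs in this commit (async multiprocess case):
#
#   AsyncLLM.add_lora(req)
#     -> AsyncMPClient.add_lora_async(req)        # _call_utility_async("add_lora", req)
#       -> EngineCore.add_lora(req)               # runs in the engine-core process
#         -> model_executor.add_lora(req)
#           -> Worker.add_lora(req)
#             -> LoRAModelRunnerMixin.add_lora(req)
#               -> self.lora_manager.add_adapter(req)   # bool result propagated back up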
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Set, Type, Union
 
 from typing_extensions import TypeVar
 
@@ -254,3 +254,19 @@ class LLMEngine:
                              f"found type: {type(tokenizer_group)}")
 
         return tokenizer_group
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        """Load a new LoRA adapter into the engine for future requests."""
+        return self.engine_core.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        """Remove an already loaded LoRA adapter."""
+        return self.engine_core.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        """List all registered adapters."""
+        return self.engine_core.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        """Prevent an adapter from being evicted."""
+        return self.engine_core.pin_lora(lora_id)
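The synchronous V1 LLMEngine gains the same four methods, which is what the new sync test drives. A minimal sketch, assuming the model, adapter path, and max_loras=4 settings from the test (anything not named in the diff is illustrative); it reproduces the eviction behaviour the test asserts: with adapter 1 pinned, adding a fifth adapter evicts the oldest unpinned one:

# Hedged sketch mirroring tests/lora/test_lora_functions.py.
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.lora.request import LoRARequest

engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                         enable_lora=True,
                         max_loras=4,
                         max_lora_rank=8,
                         max_model_len=128,
                         enforce_eager=True)
engine = LLM.get_engine_class().from_engine_args(engine_args)

def lora(i: int) -> LoRARequest:
    return LoRARequest(lora_name=f"{i}", lora_int_id=i,
                       lora_path="yard1/llama-2-7b-sql-lora-test")

for i in (1, 2, 3, 4):
    engine.add_lora(lora(i))
engine.pin_lora(1)               # adapter 1 is never evicted from now on
engine.add_lora(lora(5))         # slots full: oldest unpinned adapter (2) is evicted
print(engine.list_loras())       # the test expects {1, 5, 3, 4}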
@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Set
 
 import torch
 import torch.distributed
@@ -240,6 +240,15 @@ class Worker(WorkerBase):
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
         return
@@ -131,4 +131,19 @@ class LoRAModelRunnerMixin:
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.add_adapter(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_adapter(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_adapter(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_adapters()