mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-01 00:30:05 +08:00
[V1] LoRA - Enable Serving Usecase (#12883)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
f0b2da72a8
commit
cbc40128eb
165
tests/lora/test_add_lora.py
Normal file
165
tests/lora/test_add_lora.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.inputs import TextPrompt
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import merge_async_iterators
|
||||||
|
|
||||||
|
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
|
||||||
|
LORA_MODULE_DOWNLOAD_PATH = None # Populated by download_and_prepare_lora_module() #noqa
|
||||||
|
LORA_RANK = 8
|
||||||
|
DEFAULT_MAX_LORAS = 16 * 3
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_prepare_lora_module():
|
||||||
|
"""
|
||||||
|
Request submission is expensive when the LoRA adapters have their own
|
||||||
|
tokenizers. This is because, for each request with a new LoRA adapter ID,
|
||||||
|
the front-end loads the tokenizer from disk.
|
||||||
|
|
||||||
|
In this test, as we are comparing request processing times, we want to
|
||||||
|
minimize any extra activity. To this effect, we download the LoRA
|
||||||
|
adapter and remove all the tokenizer files, so the engine will default
|
||||||
|
to the base model tokenizer.
|
||||||
|
"""
|
||||||
|
global LORA_MODULE_DOWNLOAD_PATH
|
||||||
|
|
||||||
|
LORA_MODULE_HF_PATH = "yard1/llama-2-7b-sql-lora-test"
|
||||||
|
LORA_MODULE_DOWNLOAD_PATH = snapshot_download(repo_id=LORA_MODULE_HF_PATH)
|
||||||
|
|
||||||
|
tokenizer_files = [
|
||||||
|
'added_tokens.json', 'tokenizer_config.json', 'tokenizer.json',
|
||||||
|
'tokenizer.model'
|
||||||
|
]
|
||||||
|
for tokenizer_file in tokenizer_files:
|
||||||
|
del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
|
||||||
|
del_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines_lora):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_requests() -> List[LoRARequest]:
|
||||||
|
lora_requests: List[LoRARequest] = [
|
||||||
|
LoRARequest(lora_name=f"{i}",
|
||||||
|
lora_int_id=i,
|
||||||
|
lora_path=LORA_MODULE_DOWNLOAD_PATH)
|
||||||
|
for i in range(1, DEFAULT_MAX_LORAS + 1)
|
||||||
|
]
|
||||||
|
return lora_requests
|
||||||
|
|
||||||
|
|
||||||
|
async def requests_processing_time(llm,
|
||||||
|
lora_requests: List[LoRARequest]) -> float:
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(n=1,
|
||||||
|
temperature=0.0,
|
||||||
|
top_p=1.0,
|
||||||
|
ignore_eos=True,
|
||||||
|
max_tokens=1)
|
||||||
|
|
||||||
|
generators = []
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
for lora_request in lora_requests:
|
||||||
|
lora_int_id = lora_request.lora_int_id
|
||||||
|
generator = llm.generate(
|
||||||
|
prompt=TextPrompt(prompt=f"hello {lora_int_id}",
|
||||||
|
multi_modal_data=None), # type: ignore
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
request_id=f"test{lora_int_id}")
|
||||||
|
generators.append(generator)
|
||||||
|
|
||||||
|
all_gens = merge_async_iterators(*generators)
|
||||||
|
async for i, res in all_gens:
|
||||||
|
pass
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_lora():
|
||||||
|
"""
|
||||||
|
The add_lora function is used to pre-load some LoRA adapters into the
|
||||||
|
engine in anticipation of future requests using these adapters. To test
|
||||||
|
this functionality, we use the async engine to process some requests - We
|
||||||
|
do it twice, once with add_lora() pre-loading and once without.
|
||||||
|
|
||||||
|
We measure the request processing time in both cases and expect the time
|
||||||
|
to be lesser in the case with add_lora() calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
download_and_prepare_lora_module()
|
||||||
|
|
||||||
|
lora_requests: List[LoRARequest] = get_lora_requests()
|
||||||
|
|
||||||
|
max_loras = len(set([lr.lora_int_id for lr in lora_requests]))
|
||||||
|
# Create engine in eager-mode. Due to high max_loras, the CI can
|
||||||
|
# OOM during cuda-graph capture.
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=max_loras,
|
||||||
|
max_lora_rank=LORA_RANK,
|
||||||
|
max_model_len=128,
|
||||||
|
gpu_memory_utilization=0.8, #avoid OOM
|
||||||
|
enforce_eager=True)
|
||||||
|
|
||||||
|
# The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
|
||||||
|
# environment variable. reload vllm.enging.async_llm_engine as
|
||||||
|
# vllm.engine.async_llm_engine.AsyncLLMEgnine changes depending on the
|
||||||
|
# env var.
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import vllm.engine.async_llm_engine
|
||||||
|
importlib.reload(vllm.engine.async_llm_engine)
|
||||||
|
from vllm.entrypoints.openai.api_server import (
|
||||||
|
build_async_engine_client_from_engine_args)
|
||||||
|
|
||||||
|
# split lora_requests into 3 parts
|
||||||
|
part_size = len(lora_requests) // 3
|
||||||
|
dummy_run_requests = lora_requests[:part_size]
|
||||||
|
warmup_run_requests = lora_requests[part_size:part_size * 2]
|
||||||
|
cold_run_requests = lora_requests[part_size * 2:]
|
||||||
|
|
||||||
|
async with build_async_engine_client_from_engine_args(engine_args) as llm:
|
||||||
|
|
||||||
|
# Dummy run - So any 1-time functionality like triton kernel compilation
|
||||||
|
# is complete here.
|
||||||
|
await requests_processing_time(llm, dummy_run_requests)
|
||||||
|
|
||||||
|
# Run with warmup
|
||||||
|
for lr in warmup_run_requests:
|
||||||
|
await llm.add_lora(lr)
|
||||||
|
# Wait for the add_lora function to complete on the server side.
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
time_with_add_lora = await requests_processing_time(
|
||||||
|
llm, warmup_run_requests)
|
||||||
|
|
||||||
|
# Run without any warmup
|
||||||
|
time_cold_start = await requests_processing_time(
|
||||||
|
llm, cold_run_requests)
|
||||||
|
|
||||||
|
print(f"time hot-start {time_with_add_lora} vs "
|
||||||
|
f"time cold-start {time_cold_start} ")
|
||||||
|
|
||||||
|
assert time_with_add_lora < time_cold_start, (
|
||||||
|
f"time_with_add_lora={time_with_add_lora}, "
|
||||||
|
f"time_cold_start={time_cold_start}"
|
||||||
|
"The engine request processing time with LoRA pre-loading "
|
||||||
|
"must be less than the version that does on-demand LoRA loading.")
|
||||||
@ -134,3 +134,4 @@ class EngineCoreRequestType(enum.Enum):
|
|||||||
ABORT = b'\x01'
|
ABORT = b'\x01'
|
||||||
PROFILE = b'\x02'
|
PROFILE = b'\x02'
|
||||||
RESET_PREFIX_CACHE = b'\x03'
|
RESET_PREFIX_CACHE = b'\x03'
|
||||||
|
ADD_LORA = b'\x04'
|
||||||
|
|||||||
@ -361,6 +361,10 @@ class AsyncLLM(EngineClient):
|
|||||||
async def reset_prefix_cache(self) -> None:
|
async def reset_prefix_cache(self) -> None:
|
||||||
await self.engine_core.reset_prefix_cache_async()
|
await self.engine_core.reset_prefix_cache_async()
|
||||||
|
|
||||||
|
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
"""Load a new LoRA adapter into the engine for future requests."""
|
||||||
|
await self.engine_core.add_lora_async(lora_request)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -376,7 +380,3 @@ class AsyncLLM(EngineClient):
|
|||||||
@property
|
@property
|
||||||
def dead_error(self) -> BaseException:
|
def dead_error(self) -> BaseException:
|
||||||
return Exception() # TODO: implement
|
return Exception() # TODO: implement
|
||||||
|
|
||||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
|
||||||
"""Load a new LoRA adapter into the engine for future requests."""
|
|
||||||
raise NotImplementedError("LoRA not yet supported in V1")
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import zmq.asyncio
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
maybe_register_config_serialize_by_value)
|
maybe_register_config_serialize_by_value)
|
||||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||||
@ -146,6 +147,9 @@ class EngineCore:
|
|||||||
def reset_prefix_cache(self):
|
def reset_prefix_cache(self):
|
||||||
self.scheduler.reset_prefix_cache()
|
self.scheduler.reset_prefix_cache()
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
self.model_executor.add_lora(lora_request)
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreProc(EngineCore):
|
class EngineCoreProc(EngineCore):
|
||||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||||
@ -262,12 +266,15 @@ class EngineCoreProc(EngineCore):
|
|||||||
self.reset_prefix_cache()
|
self.reset_prefix_cache()
|
||||||
elif request_type == EngineCoreRequestType.PROFILE:
|
elif request_type == EngineCoreRequestType.PROFILE:
|
||||||
self.model_executor.profile(request)
|
self.model_executor.profile(request)
|
||||||
|
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||||
|
self.model_executor.add_lora(request)
|
||||||
|
|
||||||
def process_input_socket(self, input_path: str):
|
def process_input_socket(self, input_path: str):
|
||||||
"""Input socket IO thread."""
|
"""Input socket IO thread."""
|
||||||
|
|
||||||
# Msgpack serialization decoding.
|
# Msgpack serialization decoding.
|
||||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||||
|
add_lora_decoder = MsgpackDecoder(LoRARequest)
|
||||||
generic_decoder = MsgpackDecoder()
|
generic_decoder = MsgpackDecoder()
|
||||||
|
|
||||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||||
@ -277,9 +284,14 @@ class EngineCoreProc(EngineCore):
|
|||||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||||
|
|
||||||
# Deserialize the request data.
|
# Deserialize the request data.
|
||||||
decoder = add_request_decoder if (
|
decoder = None
|
||||||
request_type
|
if request_type == EngineCoreRequestType.ADD:
|
||||||
== EngineCoreRequestType.ADD) else generic_decoder
|
decoder = add_request_decoder
|
||||||
|
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||||
|
decoder = add_lora_decoder
|
||||||
|
else:
|
||||||
|
decoder = generic_decoder
|
||||||
|
|
||||||
request = decoder.decode(data_frame.buffer)
|
request = decoder.decode(data_frame.buffer)
|
||||||
|
|
||||||
# Push to input queue for core busy loop.
|
# Push to input queue for core busy loop.
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import zmq.asyncio
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||||
make_zmq_socket)
|
make_zmq_socket)
|
||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
@ -77,6 +78,9 @@ class EngineCoreClient(ABC):
|
|||||||
def abort_requests(self, request_ids: List[str]) -> None:
|
def abort_requests(self, request_ids: List[str]) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_output_async(self) -> EngineCoreOutputs:
|
async def get_output_async(self) -> EngineCoreOutputs:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -92,6 +96,9 @@ class EngineCoreClient(ABC):
|
|||||||
async def abort_requests_async(self, request_ids: List[str]) -> None:
|
async def abort_requests_async(self, request_ids: List[str]) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def add_lora_async(self, lora_request: LoRARequest) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class InprocClient(EngineCoreClient):
|
class InprocClient(EngineCoreClient):
|
||||||
"""
|
"""
|
||||||
@ -125,6 +132,9 @@ class InprocClient(EngineCoreClient):
|
|||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
self.engine_core.reset_prefix_cache()
|
self.engine_core.reset_prefix_cache()
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
self.engine_core.add_lora(lora_request)
|
||||||
|
|
||||||
|
|
||||||
class MPClient(EngineCoreClient):
|
class MPClient(EngineCoreClient):
|
||||||
"""
|
"""
|
||||||
@ -242,6 +252,9 @@ class SyncMPClient(MPClient):
|
|||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
|
||||||
|
|
||||||
|
|
||||||
class AsyncMPClient(MPClient):
|
class AsyncMPClient(MPClient):
|
||||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||||
@ -295,3 +308,6 @@ class AsyncMPClient(MPClient):
|
|||||||
|
|
||||||
async def reset_prefix_cache_async(self) -> None:
|
async def reset_prefix_cache_async(self) -> None:
|
||||||
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||||
|
|
||||||
|
async def add_lora_async(self, lora_request: LoRARequest) -> None:
|
||||||
|
await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce)
|
set_custom_all_reduce)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import GiB_bytes
|
from vllm.utils import GiB_bytes
|
||||||
@ -234,6 +235,9 @@ class Worker(WorkerBase):
|
|||||||
else:
|
else:
|
||||||
self.profiler.stop()
|
self.profiler.stop()
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
return self.model_runner.add_lora(lora_request)
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
# worker will always be healthy as long as it's running.
|
# worker will always be healthy as long as it's running.
|
||||||
return
|
return
|
||||||
|
|||||||
@ -127,3 +127,8 @@ class LoRAModelRunnerMixin:
|
|||||||
|
|
||||||
# __exit__ code
|
# __exit__ code
|
||||||
self.lora_manager.remove_all_adapters()
|
self.lora_manager.remove_all_adapters()
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.add_adapter(lora_request)
|
||||||
Loading…
x
Reference in New Issue
Block a user