multi-LoRA as extra models in OpenAI server (#2775)

How to serve the LoRAs (mimicking the [multilora inference example](https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py)):
```terminal
$ export LORA_PATH=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
$ python -m vllm.entrypoints.openai.api_server \
 --model meta-llama/Llama-2-7b-hf \
 --enable-lora \
 --lora-modules sql-lora=$LORA_PATH sql-lora2=$LORA_PATH
```
The above server will list 3 separate values if the user queries `/v1/models`: one for the base served model, and one for each of the specified LoRA modules. In this case `sql-lora` and `sql-lora2` point to the same underlying LoRA, but this need not be the case. LoRA config values take the same values they do in `EngineArgs`.
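
For example, listing the served models on the server above (assuming the default port 8000; the output is a sketch of the expected shape, showing only the returned ids) should yield all three entries:
```terminal
$ curl -s http://localhost:8000/v1/models | jq '.data[].id'
"meta-llama/Llama-2-7b-hf"
"sql-lora"
"sql-lora2"
```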

No work has been done here to scope client permissions to specific models.
jvmncs 2024-02-17 15:00:48 -05:00 committed by GitHub
parent 185b2c29e2
commit 8f36444c4f
7 changed files with 200 additions and 27 deletions

File 1 of 7: LoRA docs (reStructuredText)

```diff
@@ -50,3 +50,42 @@ the third parameter is the path to the LoRA adapter.
 Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
 for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
+
+Serving LoRA Adapters
+---------------------
+
+LoRA-adapted models can also be served with the OpenAI-compatible vLLM server. To do so, we use
+``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kick off the server:
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.openai.api_server \
+        --model meta-llama/Llama-2-7b-hf \
+        --enable-lora \
+        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
+
+The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``,
+``max_cpu_loras``, etc.), which will apply to all forthcoming requests. Upon querying the ``/v1/models``
+endpoint, we should see our LoRA along with its base model:
+
+.. code-block:: bash
+
+    curl localhost:8000/v1/models | jq .
+    {
+        "object": "list",
+        "data": [
+            {
+                "id": "meta-llama/Llama-2-7b-hf",
+                "object": "model",
+                ...
+            },
+            {
+                "id": "sql-lora",
+                "object": "model",
+                ...
+            }
+        ]
+    }
+
+Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The
+requests will be processed according to the server-wide LoRA configuration (i.e. in parallel with base model
+requests, and potentially other LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
```
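
The new docs above route a request to a LoRA simply by naming it in the `model` field of an ordinary completion request. A minimal sketch of such a request against the documented server (the port, prompt, and sampling values are illustrative placeholders):
```terminal
$ curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "sql-lora",
        "prompt": "Write a SQL query that counts the rows in the users table.",
        "max_tokens": 32,
        "temperature": 0
    }'
```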

File 2 of 7: examples/multilora_inference.py

```diff
@@ -12,7 +12,9 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
 from vllm.lora.request import LoRARequest
 
 
-def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
+def create_test_prompts(
+        lora_path: str
+) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
     """Create a list of test prompts with their sampling parameters.
 
     2 requests for base model, 4 requests for the LoRA. We define 2
```

File 3 of 7: OpenAI server tests (pytest)

```diff
@@ -7,9 +7,11 @@ import pytest
 import requests
 import ray  # using Ray for overall ease of process management, parallel requests, and debugging.
 import openai  # use the official client for correctness check
+from huggingface_hub import snapshot_download  # downloading lora to test lora requests
 
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
 
 pytestmark = pytest.mark.asyncio
@@ -54,7 +56,12 @@ class ServerRunner:
 
 @pytest.fixture(scope="session")
-def server():
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="session")
+def server(zephyr_lora_files):
     ray.init()
     server_runner = ServerRunner.remote([
         "--model",
@@ -64,6 +71,17 @@ def server():
         "--max-model-len",
         "8192",
         "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128"
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -79,8 +97,25 @@ def client():
     yield client
 
 
-async def test_single_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(model=MODEL_NAME,
+async def test_check_models(server, client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    lora_models = models[1:]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+    assert lora_models[0].id == "zephyr-lora"
+    assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_single_completion(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
+    completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
                                                  temperature=0.0)
@@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
         completion.choices[0].text) >= 5
 
 
-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
 
     # test single completion
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
     )
@@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     assert message.content is not None and len(message.content) >= 0
 
 
-async def test_completion_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_completion_streaming(server, client: openai.AsyncOpenAI,
+                                    model_name: str):
     prompt = "What is an LLM?"
 
     single_completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     single_usage = single_completion.usage
 
     stream = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == single_output
 
 
-async def test_chat_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_chat_streaming(server, client: openai.AsyncOpenAI,
+                              model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test single completion
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test streaming
     stream = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output
 
 
-async def test_batch_completions(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_batch_completions(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
     # test simple list
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
@@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
     # test n = 2
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         n=2,
         max_tokens=5,
@@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
     # test streaming
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
```

File 4 of 7: vllm/entrypoints/openai/api_server.py

```diff
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import LoRA
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
 app = fastapi.FastAPI(lifespan=lifespan)
 
 
+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
                         help="The model name used in the API. If not "
                         "specified, the model name will be the same as "
                         "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,
@@ -217,8 +237,10 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     openai_serving_chat = OpenAIServingChat(engine, served_model,
                                             args.response_role,
+                                            args.lora_modules,
                                             args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)
 
     # Register labels for metrics
     add_global_metrics_labels(model_name=engine_args.model)
```

File 5 of 7: vllm/entrypoints/openai/serving_chat.py

```diff
@@ -1,7 +1,7 @@
 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
@@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
                  engine: AsyncLLMEngine,
                  served_model: str,
                  response_role: str,
+                 lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
-        super().__init__(engine=engine, served_model=served_model)
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
@@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
             token_ids = self._validate_prompt_and_tokenize(request,
                                                            prompt=prompt)
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
         except ValueError as e:
             return self.create_error_response(str(e))
 
         result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                lora_request)
 
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
```

File 6 of 7: vllm/entrypoints/openai/serving_completion.py

```diff
@@ -15,7 +15,7 @@ from .protocol import (
     UsageInfo,
 )
 
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
@@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
         generators = []
         try:
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     self.engine.generate(None,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         lora_request=lora_request))
         except ValueError as e:
             return self.create_error_response(str(e))
```

File 7 of 7: vllm/entrypoints/openai/serving_engine.py

```diff
@@ -1,4 +1,5 @@
 import asyncio
+from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 from vllm.logger import init_logger
@@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ErrorResponse, LogProbs,
                                               ModelCard, ModelList,
                                               ModelPermission)
+from vllm.lora.request import LoRARequest
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class LoRA:
+    name: str
+    local_path: str
+
+
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
         self.engine = engine
         self.served_model = served_model
+        if lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                ) for i, lora in enumerate(lora_modules, start=1)
+            ]
 
         self.max_model_len = 0
         self.tokenizer = None
@@ -50,6 +71,13 @@ class OpenAIServing:
                       root=self.served_model,
                       permission=[ModelPermission()])
         ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=self.served_model,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
         return ModelList(data=model_cards)
 
     def _create_logprobs(
@@ -99,11 +127,22 @@ class OpenAIServing:
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model == self.served_model:
             return
+        if request.model in [lora.lora_name for lora in self.lora_requests]:
+            return
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
+    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+        if request.model == self.served_model:
+            return
+        for lora in self.lora_requests:
+            if request.model == lora.lora_name:
+                return lora
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError(f"The model `{request.model}` does not exist.")
+
     def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
```