[Core] Support Lora lineage and base model metadata management (#6315)

This commit is contained in:
Jiaxin Shan 2024-09-19 23:20:56 -07:00 committed by GitHub
parent 9e5ec35b1f
commit 260d40b5ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 337 additions and 45 deletions

View File

@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
-d '{
"lora_name": "sql_adapter"
}'
New format for `--lora-modules`
-------------------------------
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
.. code-block:: bash
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`.
Now, you can specify a base_model_name alongside the name and path using JSON format. For example:
.. code-block:: bash
--lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'
To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
Lora model lineage in model card
--------------------------------
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
- The `root` field points to the artifact location of the lora adapter.
.. code-block:: bash
$ curl http://localhost:8000/v1/models
{
"object": "list",
"data": [
{
"id": "meta-llama/Llama-2-7b-hf",
"object": "model",
"created": 1715644056,
"owned_by": "vllm",
"root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
"parent": null,
"permission": [
{
.....
}
]
},
{
"id": "sql-lora",
"object": "model",
"created": 1715644056,
"owned_by": "vllm",
"root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
"parent": meta-llama/Llama-2-7b-hf,
"permission": [
{
....
}
]
}
]
}

View File

@ -0,0 +1,91 @@
import json
import unittest
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser
LORA_MODULE = {
"name": "module2",
"path": "/path/to/module2",
"base_model_name": "llama"
}
class TestLoraParserAction(unittest.TestCase):
def setUp(self):
# Setting up argparse parser for tests
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
self.parser = make_arg_parser(parser)
def test_valid_key_value_format(self):
# Test old format: name=path
args = self.parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
])
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
self.assertEqual(args.lora_modules, expected)
def test_valid_json_format(self):
# Test valid JSON format input
args = self.parser.parse_args([
'--lora-modules',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)
def test_invalid_json_format(self):
# Test invalid JSON format input, missing closing brace
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module3", "path": "/path/to/module3"'
])
def test_invalid_type_error(self):
# Test type error when values are not JSON or key=value
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'invalid_format' # This is not JSON or key=value format
])
def test_invalid_json_field(self):
# Test valid JSON format but missing required fields
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module4"}' # Missing required 'path' field
])
def test_empty_values(self):
# Test when no LoRA modules are provided
args = self.parser.parse_args(['--lora-modules', ''])
self.assertEqual(args.lora_modules, [])
def test_multiple_valid_inputs(self):
# Test multiple valid inputs (both old and JSON format)
args = self.parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module1', path='/path/to/module1'),
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,83 @@
import json
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
# Define the json format LoRA module configurations
lora_module_1 = {
"name": "zephyr-lora",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"64",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client
@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"

View File

@ -51,12 +51,14 @@ async def client(server):
@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
assert served_model.root == MODEL_NAME
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"

View File

@ -7,10 +7,12 @@ from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@dataclass
@ -37,7 +39,7 @@ async def _async_serving_chat_init():
serving_completion = OpenAIServingChat(engine,
model_config,
served_model_names=[MODEL_NAME],
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
lora_modules=None,
@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
served_model_names=[MODEL_NAME],
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
lora_modules=None,

View File

@ -8,9 +8,10 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
LORA_LOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
@ -25,7 +26,7 @@ async def _async_serving_engine_init():
serving_engine = OpenAIServing(mock_engine_client,
mock_model_config,
served_model_names=[MODEL_NAME],
BASE_MODEL_PATHS,
lora_modules=None,
prompt_adapters=None,
request_logger=None)

View File

@ -50,6 +50,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
@ -476,13 +477,18 @@ def init_app_state(
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
served_model_names,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
@ -494,7 +500,7 @@ def init_app_state(
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
@ -503,13 +509,13 @@ def init_app_state(
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
served_model_names,
base_model_paths,
request_logger=request_logger,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,

View File

@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
if item in [None, '']: # Skip if item is None or empty string
continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,

View File

@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
@ -196,6 +197,10 @@ async def main(args):
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
model_config = await engine.get_model_config()
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
if args.disable_log_requests:
request_logger = None
@ -206,7 +211,7 @@ async def main(args):
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
base_model_paths,
args.response_role,
lora_modules=None,
prompt_adapters=None,
@ -216,7 +221,7 @@ async def main(args):
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
base_model_paths,
request_logger=request_logger,
)

View File

@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
@ -59,7 +60,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser: Optional[str] = None):
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
@ -262,7 +263,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
model_name = self.base_model_paths[0].name
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
@ -596,7 +597,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None

View File

@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionStreamResponse,
ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger
@ -45,7 +46,7 @@ class OpenAIServingCompletion(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
):
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
@ -89,7 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response(
"suffix is not currently supported")
model_name = self.served_model_names[0]
model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

View File

@ -14,7 +14,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
@ -73,13 +73,13 @@ class OpenAIServingEmbedding(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)

View File

@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
@ -49,6 +55,7 @@ class PromptAdapterPath:
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@ -66,7 +73,7 @@ class OpenAIServing:
self,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
@ -79,17 +86,20 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.served_model_names = served_model_names
self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
) for i, lora in enumerate(lora_modules, start=1)
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
@ -112,21 +122,23 @@ class OpenAIServing:
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model_name,
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=self.served_model_names[0],
root=base_model.model_path,
permission=[ModelPermission()])
for served_model_name in self.served_model_names
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model_names[0],
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.served_model_names[0],
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
@ -169,7 +181,7 @@ class OpenAIServing:
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
@ -187,7 +199,7 @@ class OpenAIServing:
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if request.model in self.served_model_names:
if self._is_model_supported(request.model):
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
@ -480,3 +492,6 @@ class OpenAIServing:
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)

View File

@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
@ -31,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
@ -39,7 +40,7 @@ class OpenAIServingTokenization(OpenAIServing):
):
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)

View File

@ -28,6 +28,7 @@ class LoRARequest(
lora_path: str = ""
lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None
base_model_name: Optional[str] = msgspec.field(default=None)
def __post_init__(self):
if 'lora_local_path' in self.__struct_fields__: