Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 01:45:02 +08:00)

[Core] Support Lora lineage and base model metadata management (#6315)

This commit is contained in:
  parent 9e5ec35b1f
  commit 260d40b5ea
@@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
     -d '{
         "lora_name": "sql_adapter"
     }'
+
+
+New format for `--lora-modules`
+-------------------------------
+
+In the previous version, users would provide LoRA modules as a key-value pair, for example:
+
+.. code-block:: bash
+
+    --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
+
+This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`.
+Now, you can specify a `base_model_name` alongside the `name` and `path` using JSON format. For example:
+
+.. code-block:: bash
+
+    --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'
+
+To preserve backward compatibility, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
+
+
+Lora model lineage in model card
+--------------------------------
+
+The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here is how the `/v1/models` response reflects this:
+
+- The `parent` field of the LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
+- The `root` field points to the artifact location of the LoRA adapter.
+
+.. code-block:: bash
+
+    $ curl http://localhost:8000/v1/models
+
+    {
+        "object": "list",
+        "data": [
+            {
+                "id": "meta-llama/Llama-2-7b-hf",
+                "object": "model",
+                "created": 1715644056,
+                "owned_by": "vllm",
+                "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
+                "parent": null,
+                "permission": [
+                    {
+                        .....
+                    }
+                ]
+            },
+            {
+                "id": "sql-lora",
+                "object": "model",
+                "created": 1715644056,
+                "owned_by": "vllm",
+                "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
+                "parent": "meta-llama/Llama-2-7b-hf",
+                "permission": [
+                    {
+                        ....
+                    }
+                ]
+            }
+        ]
+    }
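The lineage fields in the listing above can also be read with the official Python client instead of curl. A minimal sketch, assuming a vLLM server on localhost:8000 started with the JSON `--lora-modules` format; the base URL and api_key value are placeholders:

from openai import OpenAI

# vLLM reports "root" and "parent" as extra fields on each model card; the
# official client surfaces them as extra attributes on the returned entries.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

for model in client.models.list().data:
    print(model.id,
          getattr(model, "root", None),     # artifact location
          getattr(model, "parent", None))   # base model for LoRA adapters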
91 tests/entrypoints/openai/test_cli_args.py Normal file
@@ -0,0 +1,91 @@
import json
import unittest

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama"
}


class TestLoraParserAction(unittest.TestCase):

    def setUp(self):
        # Setting up argparse parser for tests
        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        self.parser = make_arg_parser(parser)

    def test_valid_key_value_format(self):
        # Test old format: name=path
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
        ])
        expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
        self.assertEqual(args.lora_modules, expected)

    def test_valid_json_format(self):
        # Test valid JSON format input
        args = self.parser.parse_args([
            '--lora-modules',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)

    def test_invalid_json_format(self):
        # Test invalid JSON format input, missing closing brace
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module3", "path": "/path/to/module3"'
            ])

    def test_invalid_type_error(self):
        # Test type error when values are not JSON or key=value
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                'invalid_format'  # This is not JSON or key=value format
            ])

    def test_invalid_json_field(self):
        # Test valid JSON format but missing required fields
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module4"}'  # Missing required 'path' field
            ])

    def test_empty_values(self):
        # Test when no LoRA modules are provided
        args = self.parser.parse_args(['--lora-modules', ''])
        self.assertEqual(args.lora_modules, [])

    def test_multiple_valid_inputs(self):
        # Test multiple valid inputs (both old and JSON format)
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module1', path='/path/to/module1'),
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)


if __name__ == '__main__':
    unittest.main()
83 tests/entrypoints/openai/test_lora_lineage.py Normal file
@@ -0,0 +1,83 @@
import json

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
                                  zephyr_lora_files):
    models = await client_for_lora_lineage.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"
@@ -51,12 +51,14 @@ async def client(server):


 @pytest.mark.asyncio
-async def test_check_models(client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
     lora_models = models[1:]
     assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
+    assert served_model.root == MODEL_NAME
+    assert all(lora_model.root == zephyr_lora_files
+               for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
@@ -7,10 +7,12 @@ from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


 @dataclass
@@ -37,7 +39,7 @@ async def _async_serving_chat_init():

     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():

     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,
@@ -8,9 +8,10 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing

 MODEL_NAME = "meta-llama/Llama-2-7b"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 LORA_LOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' added successfully.")
 LORA_UNLOADING_SUCCESS_MESSAGE = (
@@ -25,7 +26,7 @@ async def _async_serving_engine_init():

     serving_engine = OpenAIServing(mock_engine_client,
                                    mock_model_config,
-                                   served_model_names=[MODEL_NAME],
+                                   BASE_MODEL_PATHS,
                                    lora_modules=None,
                                    prompt_adapters=None,
                                    request_logger=None)
@@ -50,6 +50,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
@@ -476,13 +477,18 @@ def init_app_state(
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats

     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
@@ -494,7 +500,7 @@ def init_app_state(
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
@@ -503,13 +509,13 @@ def init_app_state(
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
         chat_template=args.chat_template,
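The base_model_paths construction above is a small but central piece: every alias passed via --served-model-name becomes its own entry, all pointing at the same --model path. A minimal standalone sketch of that mapping; the dataclass is a local stand-in for the BaseModelPath added in serving_engine.py, and the names and path are made up:

from dataclasses import dataclass
from typing import List


@dataclass
class BaseModelPath:  # local stand-in for the class added in serving_engine.py
    name: str
    model_path: str


# Hypothetical values for illustration only.
served_model_names: List[str] = ["meta-llama/Llama-2-7b-hf", "llama2-7b"]
model = "/models/llama-2-7b-hf"

base_model_paths = [
    BaseModelPath(name=name, model_path=model) for name in served_model_names
]
# Every alias is reported as its own model card, but all of them resolve to
# the same underlying weights path.
print(base_model_paths)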
@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):

         lora_list: List[LoRAModulePath] = []
         for item in values:
-            name, path = item.split('=')
-            lora_list.append(LoRAModulePath(name, path))
+            if item in [None, '']:  # Skip if item is None or empty string
+                continue
+            if '=' in item and ',' not in item:  # Old format: name=path
+                name, path = item.split('=')
+                lora_list.append(LoRAModulePath(name, path))
+            else:  # Assume JSON format
+                try:
+                    lora_dict = json.loads(item)
+                    lora = LoRAModulePath(**lora_dict)
+                    lora_list.append(lora)
+                except json.JSONDecodeError:
+                    parser.error(
+                        f"Invalid JSON format for --lora-modules: {item}")
+                except TypeError as e:
+                    parser.error(
+                        f"Invalid fields for --lora-modules: {item} - {str(e)}"
+                    )
         setattr(namespace, self.dest, lora_list)


@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=None,
         nargs='+',
         action=LoRAParserAction,
-        help="LoRA module configurations in the format name=path. "
-        "Multiple modules can be specified.")
+        help="LoRA module configurations in either 'name=path' format"
+        "or JSON format. "
+        "Example (old format): 'name=path' "
+        "Example (new format): "
+        "'{\"name\": \"name\", \"local_path\": \"path\", "
+        "\"base_model_name\": \"id\"}'")
     parser.add_argument(
         "--prompt-adapters",
         type=nullable_str,
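To see the dispatch rule above in isolation (skip empty values, treat 'name=path' without commas as the legacy format, fall back to JSON otherwise), here is a standalone sketch using plain argparse; the LoRAModulePath dataclass is a local stand-in for the one in vllm.entrypoints.openai.serving_engine, and the merged error handling is a simplification of the two except branches:

import argparse
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAModulePath:
    # Local stand-in for vllm.entrypoints.openai.serving_engine.LoRAModulePath.
    name: str
    path: str
    base_model_name: Optional[str] = None


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            if item in (None, ''):  # skip empty values
                continue
            if '=' in item and ',' not in item:  # legacy format: name=path
                name, path = item.split('=')
                lora_list.append(LoRAModulePath(name, path))
            else:  # otherwise treat the value as JSON
                try:
                    lora_list.append(LoRAModulePath(**json.loads(item)))
                except (json.JSONDecodeError, TypeError) as e:
                    parser.error(f"Invalid --lora-modules value {item!r}: {e}")
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args([
    "--lora-modules",
    "sql-lora=/path/to/lora",
    '{"name": "sql-lora-2", "path": "/path/to/lora", '
    '"base_model_name": "meta-llama/Llama-2-7b"}',
])
print(args.lora_modules)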
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
@@ -196,6 +197,10 @@ async def main(args):
         engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

     model_config = await engine.get_model_config()
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]

     if args.disable_log_requests:
         request_logger = None
@@ -206,7 +211,7 @@ async def main(args):
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=None,
         prompt_adapters=None,
@@ -216,7 +221,7 @@ async def main(args):
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )

@@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
     DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
+                                                    LoRAModulePath,
                                                     OpenAIServing,
                                                     PromptAdapterPath,
                                                     TextTokensPrompt)
@@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing):
     def __init__(self,
                  engine_client: EngineClient,
                  model_config: ModelConfig,
-                 served_model_names: List[str],
+                 base_model_paths: List[BaseModelPath],
                  response_role: str,
                  *,
                  lora_modules: Optional[List[LoRAModulePath]],
@@ -59,7 +60,7 @@ class OpenAIServingChat(OpenAIServing):
                 tool_parser: Optional[str] = None):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
-                        served_model_names=served_model_names,
+                        base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
@@ -262,7 +263,7 @@ class OpenAIServingChat(OpenAIServing):
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
     ) -> AsyncGenerator[str, None]:
-        model_name = self.served_model_names[0]
+        model_name = self.base_model_paths[0].name
         created_time = int(time.time())
         chunk_object_type: Final = "chat.completion.chunk"
         first_iteration = True
@@ -596,7 +597,7 @@ class OpenAIServingChat(OpenAIServing):
         tokenizer: AnyTokenizer,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:

-        model_name = self.served_model_names[0]
+        model_name = self.base_model_paths[0].name
         created_time = int(time.time())
         final_res: Optional[RequestOutput] = None

@@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionStreamResponse,
                                               ErrorResponse, UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
+                                                    LoRAModulePath,
                                                     OpenAIServing,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
@@ -45,7 +46,7 @@ class OpenAIServingCompletion(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        served_model_names: List[str],
+        base_model_paths: List[BaseModelPath],
         *,
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         served_model_names=served_model_names,
+                         base_model_paths=base_model_paths,
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
                          request_logger=request_logger,
@@ -89,7 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")

-        model_name = self.served_model_names[0]
+        model_name = self.base_model_paths[0].name
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())

@@ -14,7 +14,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
 from vllm.utils import merge_async_iterators, random_uuid
@@ -73,13 +73,13 @@ class OpenAIServingEmbedding(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        served_model_names: List[str],
+        base_model_paths: List[BaseModelPath],
         *,
         request_logger: Optional[RequestLogger],
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         served_model_names=served_model_names,
+                         base_model_paths=base_model_paths,
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
 logger = init_logger(__name__)


+@dataclass
+class BaseModelPath:
+    name: str
+    model_path: str
+
+
 @dataclass
 class PromptAdapterPath:
     name: str
@@ -49,6 +55,7 @@ class PromptAdapterPath:
 class LoRAModulePath:
     name: str
     path: str
+    base_model_name: Optional[str] = None


 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@@ -66,7 +73,7 @@ class OpenAIServing:
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        served_model_names: List[str],
+        base_model_paths: List[BaseModelPath],
         *,
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
@@ -79,17 +86,20 @@ class OpenAIServing:
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len

-        self.served_model_names = served_model_names
+        self.base_model_paths = base_model_paths

         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = []
         if lora_modules is not None:
             self.lora_requests = [
-                LoRARequest(
-                    lora_name=lora.name,
-                    lora_int_id=i,
-                    lora_path=lora.path,
-                ) for i, lora in enumerate(lora_modules, start=1)
+                LoRARequest(lora_name=lora.name,
+                            lora_int_id=i,
+                            lora_path=lora.path,
+                            base_model_name=lora.base_model_name
+                            if lora.base_model_name
+                            and self._is_model_supported(lora.base_model_name)
+                            else self.base_model_paths[0].name)
+                for i, lora in enumerate(lora_modules, start=1)
             ]

         self.prompt_adapter_requests = []
@@ -112,21 +122,23 @@ class OpenAIServing:
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=served_model_name,
+            ModelCard(id=base_model.name,
                       max_model_len=self.max_model_len,
-                      root=self.served_model_names[0],
+                      root=base_model.model_path,
                       permission=[ModelPermission()])
-            for served_model_name in self.served_model_names
+            for base_model in self.base_model_paths
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model_names[0],
+                      root=lora.local_path,
+                      parent=lora.base_model_name if lora.base_model_name else
+                      self.base_model_paths[0].name,
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
         prompt_adapter_cards = [
             ModelCard(id=prompt_adapter.prompt_adapter_name,
-                      root=self.served_model_names[0],
+                      root=self.base_model_paths[0].name,
                       permission=[ModelPermission()])
             for prompt_adapter in self.prompt_adapter_requests
         ]
@@ -169,7 +181,7 @@ class OpenAIServing:
         self,
         request: AnyRequest,
     ) -> Optional[ErrorResponse]:
-        if request.model in self.served_model_names:
+        if self._is_model_supported(request.model):
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return None
@@ -187,7 +199,7 @@ class OpenAIServing:
         self, request: AnyRequest
     ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
             None, PromptAdapterRequest]]:
-        if request.model in self.served_model_names:
+        if self._is_model_supported(request.model):
             return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
@@ -480,3 +492,6 @@ class OpenAIServing:
                 if lora_request.lora_name != lora_name
             ]
         return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+    def _is_model_supported(self, model_name):
+        return any(model.name == model_name for model in self.base_model_paths)
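The lineage bookkeeping above boils down to one rule: a LoRA adapter's parent is its declared base_model_name only when that name is actually served; otherwise it falls back to the first base model. A minimal standalone sketch of that rule, with local dataclass stand-ins and made-up paths:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BaseModelPath:   # local stand-in
    name: str
    model_path: str


@dataclass
class LoRAModulePath:  # local stand-in
    name: str
    path: str
    base_model_name: Optional[str] = None


def resolve_parent(lora: LoRAModulePath,
                   base_model_paths: List[BaseModelPath]) -> str:
    """Mirror of the fallback above: keep base_model_name only if it is
    actually served, otherwise report the first base model as the parent."""
    supported = any(m.name == lora.base_model_name for m in base_model_paths)
    if lora.base_model_name and supported:
        return lora.base_model_name
    return base_model_paths[0].name


bases = [BaseModelPath("meta-llama/Llama-2-7b", "/models/llama-2-7b")]
print(resolve_parent(LoRAModulePath("sql-lora", "/loras/sql",
                                    "meta-llama/Llama-2-7b"), bases))
print(resolve_parent(LoRAModulePath("other-lora", "/loras/other"), bases))  # falls back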
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               TokenizeRequest,
                                               TokenizeResponse)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
+                                                    LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import MistralTokenizer
@@ -31,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        served_model_names: List[str],
+        base_model_paths: List[BaseModelPath],
         *,
         lora_modules: Optional[List[LoRAModulePath]],
         request_logger: Optional[RequestLogger],
@@ -39,7 +40,7 @@ class OpenAIServingTokenization(OpenAIServing):
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         served_model_names=served_model_names,
+                         base_model_paths=base_model_paths,
                          lora_modules=lora_modules,
                          prompt_adapters=None,
                          request_logger=request_logger)
@@ -28,6 +28,7 @@ class LoRARequest(
     lora_path: str = ""
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
+    base_model_name: Optional[str] = msgspec.field(default=None)

     def __post_init__(self):
         if 'lora_local_path' in self.__struct_fields__:
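With this field in place, a LoRARequest can carry its base model alongside the adapter path. A minimal usage sketch, assuming a vLLM checkout that includes this change; the adapter path and model names are placeholders:

from vllm.lora.request import LoRARequest

# The new base_model_name metadata travels with the request and is what the
# OpenAI-compatible server reports back as the "parent" field.
req = LoRARequest(lora_name="sql-lora",
                  lora_int_id=1,
                  lora_path="/path/to/lora",
                  base_model_name="meta-llama/Llama-2-7b")
print(req.lora_name, req.base_model_name)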