mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 14:35:27 +08:00
[Core] Support Lora lineage and base model metadata management (#6315)
This commit is contained in:
parent
9e5ec35b1f
commit
260d40b5ea
@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
|
||||
-d '{
|
||||
"lora_name": "sql_adapter"
|
||||
}'
|
||||
|
||||
|
||||
New format for `--lora-modules`
|
||||
-------------------------------
|
||||
|
||||
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
|
||||
|
||||
This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`.
|
||||
Now, you can specify a base_model_name alongside the name and path using JSON format. For example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'
|
||||
|
||||
To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
|
||||
|
||||
|
||||
Lora model lineage in model card
|
||||
--------------------------------
|
||||
|
||||
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
|
||||
|
||||
- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
|
||||
- The `root` field points to the artifact location of the lora adapter.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ curl http://localhost:8000/v1/models
|
||||
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "meta-llama/Llama-2-7b-hf",
|
||||
"object": "model",
|
||||
"created": 1715644056,
|
||||
"owned_by": "vllm",
|
||||
"root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
|
||||
"parent": null,
|
||||
"permission": [
|
||||
{
|
||||
.....
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "sql-lora",
|
||||
"object": "model",
|
||||
"created": 1715644056,
|
||||
"owned_by": "vllm",
|
||||
"root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
|
||||
"parent": meta-llama/Llama-2-7b-hf,
|
||||
"permission": [
|
||||
{
|
||||
....
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
91
tests/entrypoints/openai/test_cli_args.py
Normal file
91
tests/entrypoints/openai/test_cli_args.py
Normal file
@ -0,0 +1,91 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
LORA_MODULE = {
|
||||
"name": "module2",
|
||||
"path": "/path/to/module2",
|
||||
"base_model_name": "llama"
|
||||
}
|
||||
|
||||
|
||||
class TestLoraParserAction(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Setting up argparse parser for tests
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM's remote OpenAI server.")
|
||||
self.parser = make_arg_parser(parser)
|
||||
|
||||
def test_valid_key_value_format(self):
|
||||
# Test old format: name=path
|
||||
args = self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
'module1=/path/to/module1',
|
||||
])
|
||||
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
|
||||
self.assertEqual(args.lora_modules, expected)
|
||||
|
||||
def test_valid_json_format(self):
|
||||
# Test valid JSON format input
|
||||
args = self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
json.dumps(LORA_MODULE),
|
||||
])
|
||||
expected = [
|
||||
LoRAModulePath(name='module2',
|
||||
path='/path/to/module2',
|
||||
base_model_name='llama')
|
||||
]
|
||||
self.assertEqual(args.lora_modules, expected)
|
||||
|
||||
def test_invalid_json_format(self):
|
||||
# Test invalid JSON format input, missing closing brace
|
||||
with self.assertRaises(SystemExit):
|
||||
self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
'{"name": "module3", "path": "/path/to/module3"'
|
||||
])
|
||||
|
||||
def test_invalid_type_error(self):
|
||||
# Test type error when values are not JSON or key=value
|
||||
with self.assertRaises(SystemExit):
|
||||
self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
'invalid_format' # This is not JSON or key=value format
|
||||
])
|
||||
|
||||
def test_invalid_json_field(self):
|
||||
# Test valid JSON format but missing required fields
|
||||
with self.assertRaises(SystemExit):
|
||||
self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
'{"name": "module4"}' # Missing required 'path' field
|
||||
])
|
||||
|
||||
def test_empty_values(self):
|
||||
# Test when no LoRA modules are provided
|
||||
args = self.parser.parse_args(['--lora-modules', ''])
|
||||
self.assertEqual(args.lora_modules, [])
|
||||
|
||||
def test_multiple_valid_inputs(self):
|
||||
# Test multiple valid inputs (both old and JSON format)
|
||||
args = self.parser.parse_args([
|
||||
'--lora-modules',
|
||||
'module1=/path/to/module1',
|
||||
json.dumps(LORA_MODULE),
|
||||
])
|
||||
expected = [
|
||||
LoRAModulePath(name='module1', path='/path/to/module1'),
|
||||
LoRAModulePath(name='module2',
|
||||
path='/path/to/module2',
|
||||
base_model_name='llama')
|
||||
]
|
||||
self.assertEqual(args.lora_modules, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
83
tests/entrypoints/openai/test_lora_lineage.py
Normal file
83
tests/entrypoints/openai/test_lora_lineage.py
Normal file
@ -0,0 +1,83 @@
|
||||
import json
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
|
||||
# generation quality here
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_lora_files():
|
||||
return snapshot_download(repo_id=LORA_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_with_lora_modules_json(zephyr_lora_files):
|
||||
# Define the json format LoRA module configurations
|
||||
lora_module_1 = {
|
||||
"name": "zephyr-lora",
|
||||
"path": zephyr_lora_files,
|
||||
"base_model_name": MODEL_NAME
|
||||
}
|
||||
|
||||
lora_module_2 = {
|
||||
"name": "zephyr-lora2",
|
||||
"path": zephyr_lora_files,
|
||||
"base_model_name": MODEL_NAME
|
||||
}
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
json.dumps(lora_module_1),
|
||||
json.dumps(lora_module_2),
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"64",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client_for_lora_lineage(server_with_lora_modules_json):
|
||||
async with server_with_lora_modules_json.get_async_client(
|
||||
) as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
|
||||
zephyr_lora_files):
|
||||
models = await client_for_lora_lineage.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
lora_models = models[1:]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert served_model.root == MODEL_NAME
|
||||
assert served_model.parent is None
|
||||
assert all(lora_model.root == zephyr_lora_files
|
||||
for lora_model in lora_models)
|
||||
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
assert lora_models[1].id == "zephyr-lora2"
|
||||
@ -51,12 +51,14 @@ async def client(server):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_models(client: openai.AsyncOpenAI):
|
||||
async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
lora_models = models[1:]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert all(model.root == MODEL_NAME for model in models)
|
||||
assert served_model.root == MODEL_NAME
|
||||
assert all(lora_model.root == zephyr_lora_files
|
||||
for lora_model in lora_models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
assert lora_models[1].id == "zephyr-lora2"
|
||||
|
||||
@ -7,10 +7,12 @@ from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -37,7 +39,7 @@ async def _async_serving_chat_init():
|
||||
|
||||
serving_completion = OpenAIServingChat(engine,
|
||||
model_config,
|
||||
served_model_names=[MODEL_NAME],
|
||||
BASE_MODEL_PATHS,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
lora_modules=None,
|
||||
@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
MockModelConfig(),
|
||||
served_model_names=[MODEL_NAME],
|
||||
BASE_MODEL_PATHS,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
lora_modules=None,
|
||||
|
||||
@ -8,9 +8,10 @@ from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-2-7b"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
LORA_LOADING_SUCCESS_MESSAGE = (
|
||||
"Success: LoRA adapter '{lora_name}' added successfully.")
|
||||
LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||
@ -25,7 +26,7 @@ async def _async_serving_engine_init():
|
||||
|
||||
serving_engine = OpenAIServing(mock_engine_client,
|
||||
mock_model_config,
|
||||
served_model_names=[MODEL_NAME],
|
||||
BASE_MODEL_PATHS,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
|
||||
@ -50,6 +50,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.logger import init_logger
|
||||
@ -476,13 +477,18 @@ def init_app_state(
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
|
||||
state.openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
args.response_role,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
@ -494,7 +500,7 @@ def init_app_state(
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
@ -503,13 +509,13 @@ def init_app_state(
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
|
||||
@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
|
||||
|
||||
lora_list: List[LoRAModulePath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
if item in [None, '']: # Skip if item is None or empty string
|
||||
continue
|
||||
if '=' in item and ',' not in item: # Old format: name=path
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in the format name=path. "
|
||||
"Multiple modules can be specified.")
|
||||
help="LoRA module configurations in either 'name=path' format"
|
||||
"or JSON format. "
|
||||
"Example (old format): 'name=path' "
|
||||
"Example (new format): "
|
||||
"'{\"name\": \"name\", \"local_path\": \"path\", "
|
||||
"\"base_model_name\": \"id\"}'")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=nullable_str,
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@ -196,6 +197,10 @@ async def main(args):
|
||||
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
|
||||
|
||||
model_config = await engine.get_model_config()
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
@ -206,7 +211,7 @@ async def main(args):
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
args.response_role,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
@ -216,7 +221,7 @@ async def main(args):
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
served_model_names,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
|
||||
|
||||
@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
def __init__(self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
base_model_paths: List[BaseModelPath],
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
@ -59,7 +60,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool_parser: Optional[str] = None):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
@ -262,7 +263,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
model_name = self.served_model_names[0]
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
chunk_object_type: Final = "chat.completion.chunk"
|
||||
first_iteration = True
|
||||
@ -596,7 +597,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = self.served_model_names[0]
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
final_res: Optional[RequestOutput] = None
|
||||
|
||||
|
||||
@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.logger import init_logger
|
||||
@ -45,7 +46,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
@ -89,7 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
model_name = self.served_model_names[0]
|
||||
model_name = self.base_model_paths[0].name
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
@ -73,13 +73,13 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
|
||||
@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
@ -49,6 +55,7 @@ class PromptAdapterPath:
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
|
||||
@ -66,7 +73,7 @@ class OpenAIServing:
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
@ -79,17 +86,20 @@ class OpenAIServing:
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.served_model_names = served_model_names
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
self.lora_requests = []
|
||||
if lora_modules is not None:
|
||||
self.lora_requests = [
|
||||
LoRARequest(
|
||||
lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_path=lora.path,
|
||||
) for i, lora in enumerate(lora_modules, start=1)
|
||||
LoRARequest(lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_path=lora.path,
|
||||
base_model_name=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
and self._is_model_supported(lora.base_model_name)
|
||||
else self.base_model_paths[0].name)
|
||||
for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
@ -112,21 +122,23 @@ class OpenAIServing:
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
ModelCard(id=served_model_name,
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=self.served_model_names[0],
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for served_model_name in self.served_model_names
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=self.served_model_names[0],
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.served_model_names[0],
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
@ -169,7 +181,7 @@ class OpenAIServing:
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
if request.model in self.served_model_names:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
return None
|
||||
@ -187,7 +199,7 @@ class OpenAIServing:
|
||||
self, request: AnyRequest
|
||||
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
|
||||
None, PromptAdapterRequest]]:
|
||||
if request.model in self.served_model_names:
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
for lora in self.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
@ -480,3 +492,6 @@ class OpenAIServing:
|
||||
if lora_request.lora_name != lora_name
|
||||
]
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
def _is_model_supported(self, model_name):
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
@ -31,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -39,7 +40,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
|
||||
@ -28,6 +28,7 @@ class LoRARequest(
|
||||
lora_path: str = ""
|
||||
lora_local_path: Optional[str] = msgspec.field(default=None)
|
||||
long_lora_max_len: Optional[int] = None
|
||||
base_model_name: Optional[str] = msgspec.field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
if 'lora_local_path' in self.__struct_fields__:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user