[Bugfix][Refactor] Unify model management in frontend (#11660)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Joe Runde 2024-12-31 18:21:51 -08:00 committed by GitHub
parent 0c6f998554
commit 4db72e57f6
15 changed files with 365 additions and 307 deletions
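
This change moves model, LoRA-adapter, and prompt-adapter bookkeeping out of OpenAIServing and into a single shared OpenAIServingModels instance that every handler receives. Below is a minimal sketch of the new pattern, mirroring the unit tests added in this commit (the mocked ModelConfig, the "my-model" base model name, and the adapter path are illustrative placeholders, not part of the change):

import asyncio
from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import LoadLoraAdapterRequest
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)

# Stand-in for the engine's real ModelConfig, as in the unit tests below.
model_config = MagicMock(spec=ModelConfig)
model_config.max_model_len = 2048

# One shared registry of base models, LoRA adapters and prompt adapters.
serving_models = OpenAIServingModels(
    model_config=model_config,
    base_model_paths=[BaseModelPath(name="my-model", model_path="my-model")],
    lora_modules=None,
    prompt_adapters=None,
)


async def demo() -> None:
    # Model-name resolution now lives on the shared instance.
    assert serving_models.model_name(None) == "my-model"

    # Dynamic adapter registration; the path is only recorded, not read here.
    msg = await serving_models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="my-adapter",
                               lora_path="/path/to/adapter"))
    print(msg)  # Success: LoRA adapter 'my-adapter' added successfully.

    # /v1/models is served from the same registry, so the adapter is listed.
    model_list = await serving_models.show_available_models()
    print([card.id for card in model_list.data])  # ['my-model', 'my-adapter']


asyncio.run(demo())

Handlers such as OpenAIServingChat, OpenAIServingCompletion, and OpenAIServingTokenization now take this shared instance in place of the old base_model_paths, lora_modules, and prompt_adapters arguments, so adapters loaded at runtime are visible to every route.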

View File

@ -4,7 +4,7 @@ import pytest
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.utils import FlexibleArgumentParser
from ...utils import VLLM_PATH

View File

@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
"64",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
# Enable the /v1/load_lora_adapter endpoint
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server
@ -67,7 +70,7 @@ async def client_for_lora_lineage(server_with_lora_modules_json):
@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = models.data
@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
response = await client_for_lora_lineage.post("load_lora_adapter",
cast_to=str,
body={
"lora_name":
"zephyr-lora-3",
"lora_path":
zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response
models = await client_for_lora_lineage.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"

View File

@ -8,7 +8,8 @@ from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "openai-community/gpt2"
@ -50,14 +51,13 @@ async def _async_serving_chat_init():
engine = MockEngine()
model_config = await engine.get_model_config()
models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat(engine,
model_config,
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion
@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,

View File

@ -4,11 +4,11 @@ from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.lora.request import LoRARequest
MODEL_NAME = "meta-llama/Llama-2-7b"
@ -19,47 +19,45 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' removed successfully.")
async def _async_serving_engine_init():
mock_engine_client = MagicMock(spec=EngineClient)
async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048
serving_engine = OpenAIServing(mock_engine_client,
mock_model_config,
BASE_MODEL_PATHS,
serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_engine
prompt_adapters=None)
return serving_models
@pytest.mark.asyncio
async def test_serving_model_name():
serving_engine = await _async_serving_engine_init()
assert serving_engine._get_model_name(None) == MODEL_NAME
serving_models = await _async_serving_models_init()
assert serving_models.model_name(None) == MODEL_NAME
request = LoRARequest(lora_name="adapter",
lora_path="/path/to/adapter2",
lora_int_id=1)
assert serving_engine._get_model_name(request) == request.lora_name
assert serving_models.model_name(request) == request.lora_name
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter",
lora_path="/path/to/adapter2")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_engine.lora_requests) == 1
assert serving_engine.lora_requests[0].lora_name == "adapter"
assert len(serving_models.lora_requests) == 1
assert serving_models.lora_requests[0].lora_name == "adapter"
@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="", lora_path="")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
@ -67,43 +65,43 @@ async def test_load_lora_adapter_missing_fields():
@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 1
assert len(serving_models.lora_requests) == 1
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert len(serving_engine.lora_requests) == 1
assert len(serving_models.lora_requests) == 1
@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
assert len(serving_engine.lora_requests) == 1
response = await serving_models.load_lora_adapter(request)
assert len(serving_models.lora_requests) == 1
request = UnloadLoraAdapterRequest(lora_name="adapter1")
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 0
assert len(serving_models.lora_requests) == 0
@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
@ -111,9 +109,9 @@ async def test_unload_lora_adapter_missing_fields():
@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST

View File

@ -58,7 +58,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing:
return tokenization(request)
def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = base(raw_request)
handler = models(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@router.get("/version")
@ -505,9 +511,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
handler = models(raw_request)
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
@ -518,9 +522,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
handler = models(raw_request)
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
@ -628,13 +630,18 @@ def init_app_state(
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@ -646,16 +653,14 @@ def init_app_state(
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@ -663,7 +668,7 @@ def init_app_state(
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@ -671,14 +676,13 @@ def init_app_state(
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
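
With this routing in place, /v1/load_lora_adapter and /v1/unload_lora_adapter are handled by the shared OpenAIServingModels instance directly, and anything registered there shows up in /v1/models for every handler. A rough client-side sketch against a running server, assuming it was started with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True (the host, port, and adapter path are placeholders):

import requests

BASE_URL = "http://localhost:8000"  # placeholder for the api_server address

# Register a new adapter at runtime.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={"lora_name": "my-adapter",
                           "lora_path": "/path/to/adapter"})
print(resp.status_code, resp.text)

# The adapter is now listed next to the base model.
for model in requests.get(f"{BASE_URL}/v1/models").json()["data"]:
    print(model["id"], model.get("parent"))

# Remove it again.
resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                     json={"lora_name": "my-adapter"})
print(resp.status_code, resp.text)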

View File

@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser

View File

@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
@ -213,13 +214,17 @@ async def main(args):
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
base_model_paths,
openai_serving_models,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
@ -228,7 +233,7 @@ async def main(args):
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",

View File

@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
@ -42,11 +40,9 @@ class OpenAIServingChat(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
@ -57,9 +53,7 @@ class OpenAIServingChat(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
@ -126,7 +120,7 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_name = self._get_model_name(lora_request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)

View File

@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -41,18 +39,14 @@ class OpenAIServingCompletion(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
@ -170,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators)
model_name = self._get_model_name(lora_request)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the

View File

@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
@ -46,7 +47,7 @@ class OpenAIServingEmbedding(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
@ -54,9 +55,7 @@ class OpenAIServingEmbedding(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template

View File

@ -1,7 +1,5 @@
import json
import pathlib
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
@ -28,13 +26,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission, ScoreRequest,
ErrorResponse, ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
UnloadLoraAdapterRequest)
TokenizeCompletionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
@ -48,30 +43,10 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
from vllm.utils import is_list_of, make_async, random_uuid
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest, ScoreRequest,
TokenizeCompletionRequest]
@ -96,10 +71,8 @@ class OpenAIServing:
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
@ -109,35 +82,7 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
@ -150,33 +95,6 @@ class OpenAIServing:
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
def create_error_response(
self,
message: str,
@ -205,11 +123,13 @@ class OpenAIServing:
) -> Optional[ErrorResponse]:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
if request.model in [
lora.lora_name for lora in self.models.lora_requests
]:
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return self.create_error_response(
@ -223,10 +143,10 @@ class OpenAIServing:
None, PromptAdapterRequest]]:
if self._is_model_supported(request.model):
return None, None
for lora in self.lora_requests:
for lora in self.models.lora_requests:
if request.model == lora.lora_name:
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
@ -588,91 +508,5 @@ class OpenAIServing:
return logprob.decoded_token
return tokenizer.decode(token_id)
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return self.create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return self.create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
def _get_model_name(self, lora: Optional[LoRARequest]):
"""
Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora is not None:
return lora.lora_name
return self.base_model_paths[0].name
return self.models.is_base_model(model_name)

View File

@ -0,0 +1,210 @@
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoraAdapterRequest)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]] = None,
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
):
super().__init__()
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self.is_base_model(lora.base_model_name) else
self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
def is_base_model(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora_request: LoRARequest that contains a base_model_name.
Returns:
- str: The LoRA adapter name if a request is given, otherwise the first base model name.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all
adapters"""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)

View File

@ -15,7 +15,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingChatRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators
@ -44,7 +45,7 @@ class OpenAIServingPooling(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
@ -52,9 +53,7 @@ class OpenAIServingPooling(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template

View File

@ -10,7 +10,8 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
ScoreResponse, ScoreResponseData,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
@ -50,15 +51,13 @@ class OpenAIServingScores(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
async def create_score(

View File

@ -15,9 +15,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -29,18 +28,15 @@ class OpenAIServingTokenization(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template