mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 16:55:01 +08:00
[Bugfix][Refactor] Unify model management in frontend (#11660)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
0c6f998554
commit
4db72e57f6
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
from ...utils import VLLM_PATH
|
||||
|
||||
@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
|
||||
"64",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
# Enable the /v1/load_lora_adapter endpoint
|
||||
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@ -67,7 +70,7 @@ async def client_for_lora_lineage(server_with_lora_modules_json):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
|
||||
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
|
||||
zephyr_lora_files):
|
||||
models = await client_for_lora_lineage.models.list()
|
||||
models = models.data
|
||||
@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
|
||||
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
assert lora_models[1].id == "zephyr-lora2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_lineage(
|
||||
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
|
||||
|
||||
response = await client_for_lora_lineage.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name":
|
||||
"zephyr-lora-3",
|
||||
"lora_path":
|
||||
zephyr_lora_files
|
||||
})
|
||||
# Ensure adapter loads before querying /models
|
||||
assert "success" in response
|
||||
|
||||
models = await client_for_lora_lineage.models.list()
|
||||
models = models.data
|
||||
dynamic_lora_model = models[-1]
|
||||
assert dynamic_lora_model.root == zephyr_lora_files
|
||||
assert dynamic_lora_model.parent == MODEL_NAME
|
||||
assert dynamic_lora_model.id == "zephyr-lora-3"
|
||||
|
||||
@ -8,7 +8,8 @@ from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@ -50,14 +51,13 @@ async def _async_serving_chat_init():
|
||||
engine = MockEngine()
|
||||
model_config = await engine.get_model_config()
|
||||
|
||||
models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
|
||||
serving_completion = OpenAIServingChat(engine,
|
||||
model_config,
|
||||
BASE_MODEL_PATHS,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
return serving_completion
|
||||
|
||||
@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
|
||||
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=MockModelConfig())
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
MockModelConfig(),
|
||||
BASE_MODEL_PATHS,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
|
||||
mock_engine.errored = False
|
||||
|
||||
# Initialize the serving chat
|
||||
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config)
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
mock_model_config,
|
||||
BASE_MODEL_PATHS,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
|
||||
@ -4,11 +4,11 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-2-7b"
|
||||
@ -19,47 +19,45 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||
"Success: LoRA adapter '{lora_name}' removed successfully.")
|
||||
|
||||
|
||||
async def _async_serving_engine_init():
|
||||
mock_engine_client = MagicMock(spec=EngineClient)
|
||||
async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
mock_model_config = MagicMock(spec=ModelConfig)
|
||||
# Set the max_model_len attribute to avoid missing attribute
|
||||
mock_model_config.max_model_len = 2048
|
||||
|
||||
serving_engine = OpenAIServing(mock_engine_client,
|
||||
mock_model_config,
|
||||
BASE_MODEL_PATHS,
|
||||
serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
return serving_engine
|
||||
prompt_adapters=None)
|
||||
|
||||
return serving_models
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_model_name():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
assert serving_engine._get_model_name(None) == MODEL_NAME
|
||||
serving_models = await _async_serving_models_init()
|
||||
assert serving_models.model_name(None) == MODEL_NAME
|
||||
request = LoRARequest(lora_name="adapter",
|
||||
lora_path="/path/to/adapter2",
|
||||
lora_int_id=1)
|
||||
assert serving_engine._get_model_name(request) == request.lora_name
|
||||
assert serving_models.model_name(request) == request.lora_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_success():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoraAdapterRequest(lora_name="adapter",
|
||||
lora_path="/path/to/adapter2")
|
||||
response = await serving_engine.load_lora_adapter(request)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
|
||||
assert len(serving_engine.lora_requests) == 1
|
||||
assert serving_engine.lora_requests[0].lora_name == "adapter"
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
assert serving_models.lora_requests[0].lora_name == "adapter"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_missing_fields():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoraAdapterRequest(lora_name="", lora_path="")
|
||||
response = await serving_engine.load_lora_adapter(request)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
@ -67,43 +65,43 @@ async def test_load_lora_adapter_missing_fields():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_duplicate():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoraAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
response = await serving_engine.load_lora_adapter(request)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
|
||||
lora_name='adapter1')
|
||||
assert len(serving_engine.lora_requests) == 1
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = LoadLoraAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
response = await serving_engine.load_lora_adapter(request)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
assert len(serving_engine.lora_requests) == 1
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_success():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoraAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
response = await serving_engine.load_lora_adapter(request)
|
||||
assert len(serving_engine.lora_requests) == 1
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = UnloadLoraAdapterRequest(lora_name="adapter1")
|
||||
response = await serving_engine.unload_lora_adapter(request)
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
|
||||
lora_name='adapter1')
|
||||
assert len(serving_engine.lora_requests) == 0
|
||||
assert len(serving_models.lora_requests) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_missing_fields():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
|
||||
response = await serving_engine.unload_lora_adapter(request)
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
@ -111,9 +109,9 @@ async def test_unload_lora_adapter_missing_fields():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_not_found():
|
||||
serving_engine = await _async_serving_engine_init()
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
|
||||
response = await serving_engine.unload_lora_adapter(request)
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
@ -58,7 +58,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing:
|
||||
return tokenization(request)
|
||||
|
||||
|
||||
def models(request: Request) -> OpenAIServingModels:
|
||||
return request.app.state.openai_serving_models
|
||||
|
||||
|
||||
def chat(request: Request) -> Optional[OpenAIServingChat]:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
handler = base(raw_request)
|
||||
handler = models(raw_request)
|
||||
|
||||
models = await handler.show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
|
||||
models_ = await handler.show_available_models()
|
||||
return JSONResponse(content=models_.model_dump())
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
@ -505,9 +511,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
@router.post("/v1/load_lora_adapter")
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
handler = models(raw_request)
|
||||
response = await handler.load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
@ -518,9 +522,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
@router.post("/v1/unload_lora_adapter")
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
handler = models(raw_request)
|
||||
response = await handler.unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
@ -628,13 +630,18 @@ def init_app_state(
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
)
|
||||
# TODO: The chat template is now broken for lora adapters :(
|
||||
state.openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
@ -646,16 +653,14 @@ def init_app_state(
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
) if model_config.runner_type == "generate" else None
|
||||
state.openai_serving_pooling = OpenAIServingPooling(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
@ -663,7 +668,7 @@ def init_app_state(
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
@ -671,14 +676,13 @@ def init_app_state(
|
||||
state.openai_serving_scores = OpenAIServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@ -213,13 +214,17 @@ async def main(args):
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
@ -228,7 +233,7 @@ async def main(args):
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
|
||||
@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
|
||||
RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
@ -42,11 +40,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
@ -57,9 +53,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
@ -126,7 +120,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
model_name = self._get_model_name(lora_request)
|
||||
model_name = self.models.model_name(lora_request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
|
||||
@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
@ -41,18 +39,14 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
diff_sampling_param = self.model_config.get_diff_sampling_param()
|
||||
@ -170,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self._get_model_name(lora_request)
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
|
||||
@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput)
|
||||
@ -46,7 +47,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
@ -54,9 +55,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import json
|
||||
import pathlib
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
|
||||
Optional, Sequence, Tuple, TypedDict, Union)
|
||||
@ -28,13 +26,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission, ScoreRequest,
|
||||
ErrorResponse, ScoreRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
TokenizeCompletionRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
# yapf: enable
|
||||
from vllm.inputs import TokensPrompt
|
||||
@ -48,30 +43,10 @@ from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
|
||||
from vllm.utils import is_list_of, make_async, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingCompletionRequest, ScoreRequest,
|
||||
TokenizeCompletionRequest]
|
||||
@ -96,10 +71,8 @@ class OpenAIServing:
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
@ -109,35 +82,7 @@ class OpenAIServing:
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
self.lora_requests = []
|
||||
if lora_modules is not None:
|
||||
self.lora_requests = [
|
||||
LoRARequest(lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_path=lora.path,
|
||||
base_model_name=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
and self._is_model_supported(lora.base_model_name)
|
||||
else self.base_model_paths[0].name)
|
||||
for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with pathlib.Path(prompt_adapter.local_path,
|
||||
"adapter_config.json").open() as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
@ -150,33 +95,6 @@ class OpenAIServing:
|
||||
self._tokenize_prompt_input_or_inputs,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
@ -205,11 +123,13 @@ class OpenAIServing:
|
||||
) -> Optional[ErrorResponse]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
if request.model in [
|
||||
lora.lora_name for lora in self.models.lora_requests
|
||||
]:
|
||||
return None
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
for prompt_adapter in self.models.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
return self.create_error_response(
|
||||
@ -223,10 +143,10 @@ class OpenAIServing:
|
||||
None, PromptAdapterRequest]]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
for lora in self.lora_requests:
|
||||
for lora in self.models.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
return lora, None
|
||||
for prompt_adapter in self.prompt_adapter_requests:
|
||||
for prompt_adapter in self.models.prompt_adapter_requests:
|
||||
if request.model == prompt_adapter.prompt_adapter_name:
|
||||
return None, prompt_adapter
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
@ -588,91 +508,5 @@ class OpenAIServing:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return self.create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return self.create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been"
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if either 'lora_name' or 'lora_int_id' is provided
|
||||
if not request.lora_name and not request.lora_int_id:
|
||||
return self.create_error_response(
|
||||
message=
|
||||
"either 'lora_name' and 'lora_int_id' needs to be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if not any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return self.create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name, lora_path = request.lora_name, request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
self.lora_requests.append(
|
||||
LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path))
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name = request.lora_name
|
||||
self.lora_requests = [
|
||||
lora_request for lora_request in self.lora_requests
|
||||
if lora_request.lora_name != lora_name
|
||||
]
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
def _is_model_supported(self, model_name):
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def _get_model_name(self, lora: Optional[LoRARequest]):
|
||||
"""
|
||||
Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora is not None:
|
||||
return lora.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
210
vllm/entrypoints/openai/serving_models.py
Normal file
210
vllm/entrypoints/openai/serving_models.py
Normal file
@ -0,0 +1,210 @@
|
||||
import json
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
UnloadLoraAdapterRequest)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]] = None,
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
self.lora_requests = []
|
||||
if lora_modules is not None:
|
||||
self.lora_requests = [
|
||||
LoRARequest(lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_path=lora.path,
|
||||
base_model_name=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
and self.is_base_model(lora.base_model_name) else
|
||||
self.base_model_paths[0].name)
|
||||
for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with pathlib.Path(prompt_adapter.local_path,
|
||||
"adapter_config.json").open() as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
def is_base_model(self, model_name):
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name, lora_path = request.lora_name, request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
self.lora_requests.append(
|
||||
LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path))
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name = request.lora_name
|
||||
self.lora_requests = [
|
||||
lora_request for lora_request in self.lora_requests
|
||||
if lora_request.lora_name != lora_name
|
||||
]
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been"
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if either 'lora_name' or 'lora_int_id' is provided
|
||||
if not request.lora_name and not request.lora_int_id:
|
||||
return create_error_response(
|
||||
message=
|
||||
"either 'lora_name' and 'lora_int_id' needs to be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if not any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
@ -15,7 +15,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.utils import merge_async_iterators
|
||||
@ -44,7 +45,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
@ -52,9 +53,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
|
||||
@ -10,7 +10,8 @@ from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||
ScoreResponse, ScoreResponseData,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
@ -50,15 +51,13 @@ class OpenAIServingScores(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def create_score(
|
||||
|
||||
@ -15,9 +15,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -29,18 +28,15 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=None,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
self.chat_template = chat_template
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user