[Bugfix][Refactor] Unify model management in frontend (#11660)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

parent: 0c6f998554
commit: 4db72e57f6
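The diff below moves base-model, LoRA-adapter, and prompt-adapter bookkeeping out of OpenAIServing and into a single shared OpenAIServingModels instance that every frontend handler receives. As orientation, here is a minimal sketch of how the new class might be exercised on its own, modeled on the unit tests added in this diff; the mocked ModelConfig, the model name, and the adapter name/paths are placeholders, not values from the PR.

import asyncio
from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import LoadLoraAdapterRequest
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)


async def main():
    # A mocked ModelConfig is enough for the registry itself (mirroring the
    # unit tests in this PR); a real server passes the engine's ModelConfig.
    model_config = MagicMock(spec=ModelConfig)
    model_config.max_model_len = 2048

    models = OpenAIServingModels(
        model_config=model_config,
        base_model_paths=[
            # Placeholder name and path.
            BaseModelPath(name="my-base-model", model_path="/path/to/model")
        ],
    )

    # Dynamic adapter registration, as used by /v1/load_lora_adapter.
    msg = await models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="my-adapter",
                               lora_path="/path/to/adapter"))
    print(msg)  # "Success: LoRA adapter 'my-adapter' added successfully."

    # /v1/models is now served from the same shared registry.
    model_list = await models.show_available_models()
    print([card.id for card in model_list.data])


asyncio.run(main())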
@@ -4,7 +4,7 @@ import pytest
 
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser
 
 from ...utils import VLLM_PATH
@@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
         "64",
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+    # Enable the /v1/load_lora_adapter endpoint
+    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
+
+    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
         yield remote_server
 
 
@@ -67,7 +70,7 @@ async def client_for_lora_lineage(server_with_lora_modules_json):
 
 
 @pytest.mark.asyncio
-async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
+async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
                                    zephyr_lora_files):
     models = await client_for_lora_lineage.models.list()
     models = models.data
@@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
     assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_lora_lineage(
+        client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
+
+    response = await client_for_lora_lineage.post("load_lora_adapter",
+                                                  cast_to=str,
+                                                  body={
+                                                      "lora_name":
+                                                      "zephyr-lora-3",
+                                                      "lora_path":
+                                                      zephyr_lora_files
+                                                  })
+    # Ensure adapter loads before querying /models
+    assert "success" in response
+
+    models = await client_for_lora_lineage.models.list()
+    models = models.data
+    dynamic_lora_model = models[-1]
+    assert dynamic_lora_model.root == zephyr_lora_files
+    assert dynamic_lora_model.parent == MODEL_NAME
+    assert dynamic_lora_model.id == "zephyr-lora-3"
@@ -8,7 +8,8 @@ from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
@@ -50,14 +51,13 @@ async def _async_serving_chat_init():
     engine = MockEngine()
     model_config = await engine.get_model_config()
 
+    models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           BASE_MODEL_PATHS,
+                                           models,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            chat_template_content_format="auto",
-                                           lora_modules=None,
-                                           prompt_adapters=None,
                                            request_logger=None)
     return serving_completion
 
@@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False
 
     # Initialize the serving chat
+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
     serving_chat = OpenAIServingChat(mock_engine,
                                      mock_model_config,
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -4,11 +4,11 @@ from unittest.mock import MagicMock
 import pytest
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.lora.request import LoRARequest
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
@@ -19,47 +19,45 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' removed successfully.")
 
 
-async def _async_serving_engine_init():
-    mock_engine_client = MagicMock(spec=EngineClient)
+async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config = MagicMock(spec=ModelConfig)
     # Set the max_model_len attribute to avoid missing attribute
     mock_model_config.max_model_len = 2048
 
-    serving_engine = OpenAIServing(mock_engine_client,
-                                   mock_model_config,
-                                   BASE_MODEL_PATHS,
-                                   lora_modules=None,
-                                   prompt_adapters=None,
-                                   request_logger=None)
-    return serving_engine
+    serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                         model_config=mock_model_config,
+                                         lora_modules=None,
+                                         prompt_adapters=None)
+    return serving_models
 
 
 @pytest.mark.asyncio
 async def test_serving_model_name():
-    serving_engine = await _async_serving_engine_init()
-    assert serving_engine._get_model_name(None) == MODEL_NAME
+    serving_models = await _async_serving_models_init()
+    assert serving_models.model_name(None) == MODEL_NAME
     request = LoRARequest(lora_name="adapter",
                           lora_path="/path/to/adapter2",
                           lora_int_id=1)
-    assert serving_engine._get_model_name(request) == request.lora_name
+    assert serving_models.model_name(request) == request.lora_name
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter",
                                      lora_path="/path/to/adapter2")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
-    assert len(serving_engine.lora_requests) == 1
-    assert serving_engine.lora_requests[0].lora_name == "adapter"
+    assert len(serving_models.lora_requests) == 1
+    assert serving_models.lora_requests[0].lora_name == "adapter"
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="", lora_path="")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
@@ -67,43 +65,43 @@ async def test_load_lora_adapter_missing_fields():
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_duplicate():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
-    assert len(serving_engine.lora_requests) == 1
+    response = await serving_models.load_lora_adapter(request)
+    assert len(serving_models.lora_requests) == 1
 
     request = UnloadLoraAdapterRequest(lora_name="adapter1")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 0
+    assert len(serving_models.lora_requests) == 0
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
@@ -111,9 +109,9 @@ async def test_unload_lora_adapter_missing_fields():
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_not_found():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
@@ -58,7 +58,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
@@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing:
     return tokenization(request)
 
 
+def models(request: Request) -> OpenAIServingModels:
+    return request.app.state.openai_serving_models
+
+
 def chat(request: Request) -> Optional[OpenAIServingChat]:
     return request.app.state.openai_serving_chat
 
@@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
 
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
-    handler = base(raw_request)
+    handler = models(raw_request)
 
-    models = await handler.show_available_models()
-    return JSONResponse(content=models.model_dump())
+    models_ = await handler.show_available_models()
+    return JSONResponse(content=models_.model_dump())
 
 
 @router.get("/version")
@@ -505,9 +511,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/load_lora_adapter")
     async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                 raw_request: Request):
-        for route in [chat, completion, embedding]:
-            handler = route(raw_request)
-            if handler is not None:
-                response = await handler.load_lora_adapter(request)
-                if isinstance(response, ErrorResponse):
-                    return JSONResponse(content=response.model_dump(),
+        handler = models(raw_request)
+        response = await handler.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
@@ -518,9 +522,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/unload_lora_adapter")
     async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                   raw_request: Request):
-        for route in [chat, completion, embedding]:
-            handler = route(raw_request)
-            if handler is not None:
-                response = await handler.unload_lora_adapter(request)
-                if isinstance(response, ErrorResponse):
-                    return JSONResponse(content=response.model_dump(),
+        handler = models(raw_request)
+        response = await handler.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
@@ -628,13 +630,18 @@ def init_app_state(
         resolved_chat_template = load_chat_template(args.chat_template)
         logger.info("Using supplied chat template:\n%s", resolved_chat_template)
 
+    state.openai_serving_models = OpenAIServingModels(
+        model_config=model_config,
+        base_model_paths=base_model_paths,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+    )
+    # TODO: The chat template is now broken for lora adapters :(
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
-        base_model_paths,
+        state.openai_serving_models,
         args.response_role,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
@@ -646,16 +653,14 @@ def init_app_state(
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
-        base_model_paths,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
+        state.openai_serving_models,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     ) if model_config.runner_type == "generate" else None
     state.openai_serving_pooling = OpenAIServingPooling(
         engine_client,
         model_config,
-        base_model_paths,
+        state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
@@ -663,7 +668,7 @@ def init_app_state(
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
-        base_model_paths,
+        state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
@@ -671,14 +676,13 @@ def init_app_state(
     state.openai_serving_scores = OpenAIServingScores(
         engine_client,
         model_config,
-        base_model_paths,
+        state.openai_serving_models,
         request_logger=request_logger
     ) if model_config.task == "score" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
-        base_model_paths,
-        lora_modules=args.lora_modules,
+        state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
@@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
-from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.utils import FlexibleArgumentParser
@@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
@@ -213,13 +214,17 @@ async def main(args):
         request_logger = RequestLogger(max_log_len=args.max_log_len)
 
     # Create the openai serving objects.
+    openai_serving_models = OpenAIServingModels(
+        model_config=model_config,
+        base_model_paths=base_model_paths,
+        lora_modules=None,
+        prompt_adapters=None,
+    )
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
-        base_model_paths,
+        openai_serving_models,
         args.response_role,
-        lora_modules=None,
-        prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
@@ -228,7 +233,7 @@ async def main(args):
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
-        base_model_paths,
+        openai_serving_models,
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
@@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
     DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
     RequestResponseMetadata, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
-                                                    LoRAModulePath,
-                                                    OpenAIServing,
-                                                    PromptAdapterPath)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -42,11 +40,9 @@ class OpenAIServingChat(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         response_role: str,
         *,
-        lora_modules: Optional[List[LoRAModulePath]],
-        prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
@@ -57,9 +53,7 @@ class OpenAIServingChat(OpenAIServing):
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=lora_modules,
-                         prompt_adapters=prompt_adapters,
+                         models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
@@ -126,7 +120,7 @@ class OpenAIServingChat(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)
 
-        model_name = self._get_model_name(lora_request)
+        model_name = self.models.model_name(lora_request)
 
         tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
@@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               RequestResponseMetadata,
                                               UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
-                                                    LoRAModulePath,
-                                                    OpenAIServing,
-                                                    PromptAdapterPath)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -41,18 +39,14 @@ class OpenAIServingCompletion(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
-        lora_modules: Optional[List[LoRAModulePath]],
-        prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=lora_modules,
-                         prompt_adapters=prompt_adapters,
+                         models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
         diff_sampling_param = self.model_config.get_diff_sampling_param()
@@ -170,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self._get_model_name(lora_request)
+        model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput)
@@ -46,7 +47,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
@@ -54,9 +55,7 @@ class OpenAIServingEmbedding(OpenAIServing):
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=None,
-                         prompt_adapters=None,
+                         models=models,
                          request_logger=request_logger)
 
         self.chat_template = chat_template
@@ -1,7 +1,5 @@
 import json
-import pathlib
 from concurrent.futures.thread import ThreadPoolExecutor
-from dataclasses import dataclass
 from http import HTTPStatus
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                     Optional, Sequence, Tuple, TypedDict, Union)
@@ -28,13 +26,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DetokenizeRequest,
                                               EmbeddingChatRequest,
                                               EmbeddingCompletionRequest,
-                                              ErrorResponse,
-                                              LoadLoraAdapterRequest,
-                                              ModelCard, ModelList,
-                                              ModelPermission, ScoreRequest,
+                                              ErrorResponse, ScoreRequest,
                                               TokenizeChatRequest,
-                                              TokenizeCompletionRequest,
-                                              UnloadLoraAdapterRequest)
+                                              TokenizeCompletionRequest)
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 # yapf: enable
 from vllm.inputs import TokensPrompt
@@ -48,30 +43,10 @@ from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
+from vllm.utils import is_list_of, make_async, random_uuid
 
 logger = init_logger(__name__)
 
-
-@dataclass
-class BaseModelPath:
-    name: str
-    model_path: str
-
-
-@dataclass
-class PromptAdapterPath:
-    name: str
-    local_path: str
-
-
-@dataclass
-class LoRAModulePath:
-    name: str
-    path: str
-    base_model_name: Optional[str] = None
-
-
 CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                               EmbeddingCompletionRequest, ScoreRequest,
                               TokenizeCompletionRequest]
@@ -96,10 +71,8 @@ class OpenAIServing:
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
-        lora_modules: Optional[List[LoRAModulePath]],
-        prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
     ):
@@ -109,35 +82,7 @@ class OpenAIServing:
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len
 
-        self.base_model_paths = base_model_paths
+        self.models = models
 
-        self.lora_id_counter = AtomicCounter(0)
-        self.lora_requests = []
-        if lora_modules is not None:
-            self.lora_requests = [
-                LoRARequest(lora_name=lora.name,
-                            lora_int_id=i,
-                            lora_path=lora.path,
-                            base_model_name=lora.base_model_name
-                            if lora.base_model_name
-                            and self._is_model_supported(lora.base_model_name)
-                            else self.base_model_paths[0].name)
-                for i, lora in enumerate(lora_modules, start=1)
-            ]
-
-        self.prompt_adapter_requests = []
-        if prompt_adapters is not None:
-            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
-                with pathlib.Path(prompt_adapter.local_path,
-                                  "adapter_config.json").open() as f:
-                    adapter_config = json.load(f)
-                num_virtual_tokens = adapter_config["num_virtual_tokens"]
-                self.prompt_adapter_requests.append(
-                    PromptAdapterRequest(
-                        prompt_adapter_name=prompt_adapter.name,
-                        prompt_adapter_id=i,
-                        prompt_adapter_local_path=prompt_adapter.local_path,
-                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))
-
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
@@ -150,33 +95,6 @@ class OpenAIServing:
             self._tokenize_prompt_input_or_inputs,
             executor=self._tokenizer_executor)
 
-    async def show_available_models(self) -> ModelList:
-        """Show available models. Right now we only have one model."""
-        model_cards = [
-            ModelCard(id=base_model.name,
-                      max_model_len=self.max_model_len,
-                      root=base_model.model_path,
-                      permission=[ModelPermission()])
-            for base_model in self.base_model_paths
-        ]
-        lora_cards = [
-            ModelCard(id=lora.lora_name,
-                      root=lora.local_path,
-                      parent=lora.base_model_name if lora.base_model_name else
-                      self.base_model_paths[0].name,
-                      permission=[ModelPermission()])
-            for lora in self.lora_requests
-        ]
-        prompt_adapter_cards = [
-            ModelCard(id=prompt_adapter.prompt_adapter_name,
-                      root=self.base_model_paths[0].name,
-                      permission=[ModelPermission()])
-            for prompt_adapter in self.prompt_adapter_requests
-        ]
-        model_cards.extend(lora_cards)
-        model_cards.extend(prompt_adapter_cards)
-        return ModelList(data=model_cards)
-
     def create_error_response(
             self,
             message: str,
@@ -205,11 +123,13 @@ class OpenAIServing:
     ) -> Optional[ErrorResponse]:
         if self._is_model_supported(request.model):
             return None
-        if request.model in [lora.lora_name for lora in self.lora_requests]:
+        if request.model in [
+                lora.lora_name for lora in self.models.lora_requests
+        ]:
             return None
         if request.model in [
                 prompt_adapter.prompt_adapter_name
-                for prompt_adapter in self.prompt_adapter_requests
+                for prompt_adapter in self.models.prompt_adapter_requests
         ]:
             return None
         return self.create_error_response(
@@ -223,10 +143,10 @@ class OpenAIServing:
                                None, PromptAdapterRequest]]:
         if self._is_model_supported(request.model):
             return None, None
-        for lora in self.lora_requests:
+        for lora in self.models.lora_requests:
             if request.model == lora.lora_name:
                 return lora, None
-        for prompt_adapter in self.prompt_adapter_requests:
+        for prompt_adapter in self.models.prompt_adapter_requests:
             if request.model == prompt_adapter.prompt_adapter_name:
                 return None, prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
@@ -588,91 +508,5 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    async def _check_load_lora_adapter_request(
-            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
-        # Check if both 'lora_name' and 'lora_path' are provided
-        if not request.lora_name or not request.lora_path:
-            return self.create_error_response(
-                message="Both 'lora_name' and 'lora_path' must be provided.",
-                err_type="InvalidUserInput",
-                status_code=HTTPStatus.BAD_REQUEST)
-
-        # Check if the lora adapter with the given name already exists
-        if any(lora_request.lora_name == request.lora_name
-               for lora_request in self.lora_requests):
-            return self.create_error_response(
-                message=
-                f"The lora adapter '{request.lora_name}' has already been"
-                "loaded.",
-                err_type="InvalidUserInput",
-                status_code=HTTPStatus.BAD_REQUEST)
-
-        return None
-
-    async def _check_unload_lora_adapter_request(
-            self,
-            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
-        # Check if either 'lora_name' or 'lora_int_id' is provided
-        if not request.lora_name and not request.lora_int_id:
-            return self.create_error_response(
-                message=
-                "either 'lora_name' and 'lora_int_id' needs to be provided.",
-                err_type="InvalidUserInput",
-                status_code=HTTPStatus.BAD_REQUEST)
-
-        # Check if the lora adapter with the given name exists
-        if not any(lora_request.lora_name == request.lora_name
-                   for lora_request in self.lora_requests):
-            return self.create_error_response(
-                message=
-                f"The lora adapter '{request.lora_name}' cannot be found.",
-                err_type="InvalidUserInput",
-                status_code=HTTPStatus.BAD_REQUEST)
-
-        return None
-
-    async def load_lora_adapter(
-            self,
-            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_load_lora_adapter_request(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        lora_name, lora_path = request.lora_name, request.lora_path
-        unique_id = self.lora_id_counter.inc(1)
-        self.lora_requests.append(
-            LoRARequest(lora_name=lora_name,
-                        lora_int_id=unique_id,
-                        lora_path=lora_path))
-        return f"Success: LoRA adapter '{lora_name}' added successfully."
-
-    async def unload_lora_adapter(
-            self,
-            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_unload_lora_adapter_request(request
-                                                                        )
-        if error_check_ret is not None:
-            return error_check_ret
-
-        lora_name = request.lora_name
-        self.lora_requests = [
-            lora_request for lora_request in self.lora_requests
-            if lora_request.lora_name != lora_name
-        ]
-        return f"Success: LoRA adapter '{lora_name}' removed successfully."
-
     def _is_model_supported(self, model_name):
-        return any(model.name == model_name for model in self.base_model_paths)
-
-    def _get_model_name(self, lora: Optional[LoRARequest]):
-        """
-        Returns the appropriate model name depending on the availability
-        and support of the LoRA or base model.
-        Parameters:
-        - lora: LoRARequest that contain a base_model_name.
-        Returns:
-        - str: The name of the base model or the first available model path.
-        """
-        if lora is not None:
-            return lora.lora_name
-        return self.base_model_paths[0].name
+        return self.models.is_base_model(model_name)
vllm/entrypoints/openai/serving_models.py (new file, 210 lines)
@@ -0,0 +1,210 @@
+import json
+import pathlib
+from dataclasses import dataclass
+from http import HTTPStatus
+from typing import List, Optional, Union
+
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import (ErrorResponse,
+                                              LoadLoraAdapterRequest,
+                                              ModelCard, ModelList,
+                                              ModelPermission,
+                                              UnloadLoraAdapterRequest)
+from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.utils import AtomicCounter
+
+
+@dataclass
+class BaseModelPath:
+    name: str
+    model_path: str
+
+
+@dataclass
+class PromptAdapterPath:
+    name: str
+    local_path: str
+
+
+@dataclass
+class LoRAModulePath:
+    name: str
+    path: str
+    base_model_name: Optional[str] = None
+
+
+class OpenAIServingModels:
+    """Shared instance to hold data about the loaded base model(s) and adapters.
+
+    Handles the routes:
+    - /v1/models
+    - /v1/load_lora_adapter
+    - /v1/unload_lora_adapter
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        base_model_paths: List[BaseModelPath],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]] = None,
+        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+    ):
+        super().__init__()
+
+        self.base_model_paths = base_model_paths
+        self.max_model_len = model_config.max_model_len
+
+        self.lora_id_counter = AtomicCounter(0)
+        self.lora_requests = []
+        if lora_modules is not None:
+            self.lora_requests = [
+                LoRARequest(lora_name=lora.name,
+                            lora_int_id=i,
+                            lora_path=lora.path,
+                            base_model_name=lora.base_model_name
+                            if lora.base_model_name
+                            and self.is_base_model(lora.base_model_name) else
+                            self.base_model_paths[0].name)
+                for i, lora in enumerate(lora_modules, start=1)
+            ]
+
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with pathlib.Path(prompt_adapter.local_path,
+                                  "adapter_config.json").open() as f:
+                    adapter_config = json.load(f)
+                num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                self.prompt_adapter_requests.append(
+                    PromptAdapterRequest(
+                        prompt_adapter_name=prompt_adapter.name,
+                        prompt_adapter_id=i,
+                        prompt_adapter_local_path=prompt_adapter.local_path,
+                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))
+
+    def is_base_model(self, model_name):
+        return any(model.name == model_name for model in self.base_model_paths)
+
+    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
+        """Returns the appropriate model name depending on the availability
+        and support of the LoRA or base model.
+        Parameters:
+        - lora: LoRARequest that contain a base_model_name.
+        Returns:
+        - str: The name of the base model or the first available model path.
+        """
+        if lora_request is not None:
+            return lora_request.lora_name
+        return self.base_model_paths[0].name
+
+    async def show_available_models(self) -> ModelList:
+        """Show available models. This includes the base model and all
+        adapters"""
+        model_cards = [
+            ModelCard(id=base_model.name,
+                      max_model_len=self.max_model_len,
+                      root=base_model.model_path,
+                      permission=[ModelPermission()])
+            for base_model in self.base_model_paths
+        ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=lora.local_path,
+                      parent=lora.base_model_name if lora.base_model_name else
+                      self.base_model_paths[0].name,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.base_model_paths[0].name,
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
+        model_cards.extend(lora_cards)
+        model_cards.extend(prompt_adapter_cards)
+        return ModelList(data=model_cards)
+
+    async def load_lora_adapter(
+            self,
+            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_load_lora_adapter_request(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name, lora_path = request.lora_name, request.lora_path
+        unique_id = self.lora_id_counter.inc(1)
+        self.lora_requests.append(
+            LoRARequest(lora_name=lora_name,
+                        lora_int_id=unique_id,
+                        lora_path=lora_path))
+        return f"Success: LoRA adapter '{lora_name}' added successfully."
+
+    async def unload_lora_adapter(
+            self,
+            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_unload_lora_adapter_request(request
+                                                                        )
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name = request.lora_name
+        self.lora_requests = [
+            lora_request for lora_request in self.lora_requests
+            if lora_request.lora_name != lora_name
+        ]
+        return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+    async def _check_load_lora_adapter_request(
+            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if both 'lora_name' and 'lora_path' are provided
+        if not request.lora_name or not request.lora_path:
+            return create_error_response(
+                message="Both 'lora_name' and 'lora_path' must be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name already exists
+        if any(lora_request.lora_name == request.lora_name
+               for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' has already been"
+                "loaded.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def _check_unload_lora_adapter_request(
+            self,
+            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if either 'lora_name' or 'lora_int_id' is provided
+        if not request.lora_name and not request.lora_int_id:
+            return create_error_response(
+                message=
+                "either 'lora_name' and 'lora_int_id' needs to be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name exists
+        if not any(lora_request.lora_name == request.lora_name
+                   for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' cannot be found.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+
+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+    return ErrorResponse(message=message,
+                         type=err_type,
+                         code=status_code.value)
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               PoolingChatRequest,
                                               PoolingRequest, PoolingResponse,
                                               PoolingResponseData, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
@@ -44,7 +45,7 @@ class OpenAIServingPooling(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
@@ -52,9 +53,7 @@ class OpenAIServingPooling(OpenAIServing):
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=None,
-                         prompt_adapters=None,
+                         models=models,
                          request_logger=request_logger)
 
         self.chat_template = chat_template
@@ -10,7 +10,8 @@ from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
                                               ScoreResponse, ScoreResponseData,
                                               UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
@@ -50,15 +51,13 @@ class OpenAIServingScores(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=None,
-                         prompt_adapters=None,
+                         models=models,
                          request_logger=request_logger)
 
     async def create_score(
@@ -15,9 +15,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               TokenizeRequest,
                                               TokenizeResponse)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
-                                                    LoRAModulePath,
-                                                    OpenAIServing)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -29,18 +28,15 @@ class OpenAIServingTokenization(OpenAIServing):
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
-        lora_modules: Optional[List[LoRAModulePath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=lora_modules,
-                         prompt_adapters=None,
+                         models=models,
                          request_logger=request_logger)
 
         self.chat_template = chat_template
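For completeness, a hedged client-side sketch of the dynamic-loading flow that test_dynamic_lora_lineage exercises above. It assumes a LoRA-enabled vLLM OpenAI-compatible server is already running with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True (as in the test fixture); the base URL, API key, adapter name, and adapter path below are placeholders.

import asyncio

import openai  # official OpenAI client, as used by the tests in this diff


async def main():
    # Placeholder connection details for an already-running vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")

    # Register an adapter at runtime via /v1/load_lora_adapter ...
    response = await client.post("load_lora_adapter",
                                 cast_to=str,
                                 body={
                                     "lora_name": "my-adapter",
                                     "lora_path": "/path/to/adapter"
                                 })
    print(response)

    # ... and it immediately shows up in /v1/models, which is now served by
    # the same shared OpenAIServingModels instance.
    models = await client.models.list()
    print([model.id for model in models.data])


asyncio.run(main())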