From 3d64cf019e85385e33b522f65ad9e1e7c665c3e3 Mon Sep 17 00:00:00 2001 From: akxxsb Date: Wed, 5 Jul 2023 12:39:59 +0800 Subject: [PATCH] [Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357) --- vllm/entrypoints/openai/api_server.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2d1dcab3a862..b1a751d6e7f0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -13,8 +13,9 @@ from fastapi import BackgroundTasks, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse -from fastchat.conversation import (Conversation, SeparatorStyle, - get_conv_template) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + import uvicorn from vllm.engine.arg_utils import AsyncEngineArgs @@ -36,7 +37,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None -chat_template = None app = fastapi.FastAPI() @@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]: async def get_gen_prompt(request) -> str: - conv = get_conv_template(chat_template) + conv = get_conversation_template(request.model) conv = Conversation( name=conv.name, system=conv.system, @@ -560,14 +560,7 @@ if __name__ == "__main__": help="The model name used in the API. If not " "specified, the model name will be the same as " "the huggingface name.") - parser.add_argument( - "--chat-template", - type=str, - default=None, - help="The chat template name used in the ChatCompletion endpoint. If " - "not specified, we use the API model name as the template name. See " - "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py " - "for the list of available templates.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -586,11 +579,6 @@ if __name__ == "__main__": else: served_model = args.model - if args.chat_template is not None: - chat_template = args.chat_template - else: - chat_template = served_model - engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) engine_model_config = asyncio.run(engine.get_model_config())