[Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357)

This commit is contained in:
akxxsb 2023-07-05 12:39:59 +08:00 committed by GitHub
parent 98fe8cb542
commit 3d64cf019e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -13,8 +13,9 @@ from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import (Conversation, SeparatorStyle,
-                                   get_conv_template)
+from fastchat.conversation import Conversation, SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template
 import uvicorn
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -36,7 +37,6 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds
 logger = init_logger(__name__)
 served_model = None
-chat_template = None
 app = fastapi.FastAPI()
@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
 async def get_gen_prompt(request) -> str:
-    conv = get_conv_template(chat_template)
+    conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
         system=conv.system,
@@ -560,14 +560,7 @@ if __name__ == "__main__":
         help="The model name used in the API. If not "
         "specified, the model name will be the same as "
         "the huggingface name.")
-    parser.add_argument(
-        "--chat-template",
-        type=str,
-        default=None,
-        help="The chat template name used in the ChatCompletion endpoint. If "
-        "not specified, we use the API model name as the template name. See "
-        "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
-        "for the list of available templates.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -586,11 +579,6 @@ if __name__ == "__main__":
     else:
         served_model = args.model
-    if args.chat_template is not None:
-        chat_template = args.chat_template
-    else:
-        chat_template = served_model
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())