mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:05:01 +08:00
[Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357)
This commit is contained in:
parent
98fe8cb542
commit
3d64cf019e
@@ -13,8 +13,9 @@ from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import (Conversation, SeparatorStyle,
-                                   get_conv_template)
+from fastchat.conversation import Conversation, SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template

 import uvicorn

 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -36,7 +37,6 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
-chat_template = None
 app = fastapi.FastAPI()

@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:


 async def get_gen_prompt(request) -> str:
-    conv = get_conv_template(chat_template)
+    conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
         system=conv.system,
@@ -560,14 +560,7 @@ if __name__ == "__main__":
         help="The model name used in the API. If not "
         "specified, the model name will be the same as "
         "the huggingface name.")
-    parser.add_argument(
-        "--chat-template",
-        type=str,
-        default=None,
-        help="The chat template name used in the ChatCompletion endpoint. If "
-        "not specified, we use the API model name as the template name. See "
-        "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
-        "for the list of available templates.")

     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

@@ -586,11 +579,6 @@ if __name__ == "__main__":
     else:
         served_model = args.model

-    if args.chat_template is not None:
-        chat_template = args.chat_template
-    else:
-        chat_template = served_model
-
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())

Loading…
x
Reference in New Issue
Block a user