feat: add ChatCompletion endpoint in OpenAI demo server. (#330)
commit 49b26e2cec (parent dafd924c1f)
vllm/entrypoints/openai/api_server.py

@@ -4,7 +4,7 @@ import argparse
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union, Any
 
 import fastapi
 from fastapi import BackgroundTasks, Request
@@ -17,8 +17,12 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+    CompletionResponseStreamChoice, CompletionStreamResponse,
+    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
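The new imports bring FastChat's conversation-template machinery into the server: get_conv_template looks up a registered prompt template by name, and Conversation carries the roles, separators, and stop strings needed to flatten a chat transcript into a single prompt string. A minimal sketch of that API as the commit uses it (the template name "vicuna_v1.1" is an illustrative assumption, not something the diff pins down):

    from fastchat.conversation import get_conv_template

    # Hypothetical: look up a template FastChat registers under "vicuna_v1.1".
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "What is the capital of France?")
    conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
    print(conv.get_prompt())                  # one flattened prompt string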
@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
+async def get_gen_prompt(request) -> str:
+    conv = get_conv_template(request.model)
+    conv = Conversation(
+        name=conv.name,
+        system=conv.system,
+        roles=conv.roles,
+        messages=list(conv.messages),  # prevent in-place modification
+        offset=conv.offset,
+        sep_style=SeparatorStyle(conv.sep_style),
+        sep=conv.sep,
+        sep2=conv.sep2,
+        stop_str=conv.stop_str,
+        stop_token_ids=conv.stop_token_ids,
+    )
+
+    if isinstance(request.messages, str):
+        prompt = request.messages
+    else:
+        for message in request.messages:
+            msg_role = message["role"]
+            if msg_role == "system":
+                conv.system = message["content"]
+            elif msg_role == "user":
+                conv.append_message(conv.roles[0], message["content"])
+            elif msg_role == "assistant":
+                conv.append_message(conv.roles[1], message["content"])
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+    return prompt
+
+
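get_gen_prompt accepts either a raw string or an OpenAI-style messages list, folding system/user/assistant turns into the FastChat template selected by request.model. A hedged driver for it, using a mock request object (the template name is again an assumption; a model name without a registered template would fail the lookup):

    import asyncio
    from types import SimpleNamespace

    # Hypothetical smoke test: request.model must name a registered template.
    mock_request = SimpleNamespace(
        model="vicuna_v1.1",
        messages=[
            {"role": "system", "content": "You are a terse assistant."},
            {"role": "user", "content": "Ping?"},
        ],
    )
    print(asyncio.run(get_gen_prompt(mock_request)))

The hunk continues with the length check that guards against over-long requests: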
+async def check_length(request, prompt, engine):
+    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
+        context_len = engine.engine.model_config.hf_config.max_sequence_length
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
+        context_len = engine.engine.model_config.hf_config.max_position_embeddings
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    else:
+        context_len = 2048
+
+    input_ids = tokenizer(prompt).input_ids
+    token_num = len(input_ids)
+
+    if token_num + request.max_tokens > context_len:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            f"This model's maximum context length is {context_len} tokens. "
+            f"However, you requested {request.max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{request.max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.",
+        )
+    else:
+        return None
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
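check_length probes hf_config attributes in a fixed order and falls back to 2048 tokens when nothing matches; note that the second seq_length branch is redundant, since the first identical probe already catches it. The same lookup, compressed into a sketch (the helper name resolve_context_len is ours, not the commit's):

    # Minimal sketch, not the commit's code: resolve the context window from
    # an hf_config-like object, mirroring check_length's attribute order.
    def resolve_context_len(hf_config, default: int = 2048) -> int:
        for attr in ("max_sequence_length", "seq_length",
                     "max_position_embeddings"):
            if hasattr(hf_config, attr):
                return getattr(hf_config, attr)
        return default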
@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int],
     return logprobs
 
 
+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    logger.info(f"Received chat completion request: {request}")
+
+    error_check_ret = await check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    if request.logit_bias is not None:
+        # TODO: support logit_bias in vLLM engine.
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "logit_bias is not currently supported")
+
+    prompt = await get_gen_prompt(request)
+    error_check_ret = await check_length(request, prompt, engine)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = f"cmpl-{random_uuid()}"
+    created_time = int(time.time())
+    try:
+        sampling_params = SamplingParams(
+            n=request.n,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
+            max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
+        )
+    except ValueError as e:
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
+    result_generator = engine.generate(prompt, sampling_params,
+                                       request_id)
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
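Each tunable on ChatCompletionRequest maps one-to-one onto a SamplingParams field, and invalid settings (e.g. a negative temperature) raise ValueError, which the endpoint converts into an HTTP 400 rather than a server error. An illustrative construction with the request defaults (values are ours):

    from vllm.sampling_params import SamplingParams

    # Mirrors the ChatCompletionRequest defaults; a bad combination would
    # raise ValueError, surfaced by the endpoint as HTTPStatus.BAD_REQUEST.
    params = SamplingParams(n=1, temperature=0.7, top_p=1.0, max_tokens=16)

The hunk then defines the streaming helpers: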
+    def create_stream_response_json(index: int,
+                                    text: str,
+                                    finish_reason: Optional[str] = None) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.json(ensure_ascii=False)
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=request_id, choices=[choice_data], model=model_name
+            )
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+                response_json = create_stream_response_json(
+                    index=i,
+                    text=delta_text,
+                )
+                yield f"data: {response_json}\n\n"
+                if output.finish_reason is not None:
+                    response_json = create_stream_response_json(
+                        index=i,
+                        text="",
+                        finish_reason=output.finish_reason,
+                    )
+                    yield f"data: {response_json}\n\n"
+        yield "data: [DONE]\n\n"
+
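Each yielded event is a "data: <json>" line followed by a blank line, i.e. standard server-sent events: the first chunk per choice carries only the assistant role, later chunks carry text deltas, and the stream closes with "data: [DONE]". A hedged client sketch, assuming the demo server is reachable on localhost:8000 (its usual default port) and a placeholder model name:

    import json
    import requests  # illustrative client, not part of the commit

    body = {
        "model": "vicuna-7b-v1.1",  # placeholder; must match the served model
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": True,
    }
    with requests.post("http://localhost:8000/v1/chat/completions",
                       json=body, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            payload = line.decode()[len("data: "):]
            if payload == "[DONE]":
                break
            delta = json.loads(payload)["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)

The hunk closes with the streaming and non-streaming return paths: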
+    # Streaming response
+    if request.stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
+
+    # Non-streaming response
+    final_res: RequestOutput = None
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await abort_request()
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
+        final_res = res
+    assert final_res is not None
+    choices = []
+    for output in final_res.outputs:
+        choice_data = ChatCompletionResponseChoice(
+            index=output.index,
+            message=ChatMessage(role="assistant", content=output.text),
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(len(output.token_ids)
+                               for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    if request.stream:
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        response_json = response.json(ensure_ascii=False)
+
+        async def fake_stream_generator() -> AsyncGenerator[str, None]:
+            yield f"data: {response_json}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(fake_stream_generator(),
+                                 media_type="text/event-stream")
+
+    return response
+
+
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
     """Completion API similar to OpenAI's API.
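That completes the api_server.py changes. The non-streaming path drains the generator, keeping only the last RequestOutput (each res carries the full text generated so far, not a delta), then assembles the choices and token usage. A hedged non-streaming call, under the same placeholder assumptions as above:

    import requests  # illustrative client, not part of the commit

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "vicuna-7b-v1.1",  # placeholder model name
            "messages": [{"role": "user", "content": "Say hello."}],
            "max_tokens": 32,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])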
vllm/entrypoints/openai/protocol.py

@@ -53,16 +53,22 @@ class UsageInfo(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Dict[str, str]]
+    messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):
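Two defaults change deliberately: max_tokens gains a concrete default of 16, presumably so SamplingParams always receives an integer, and stop defaults to an empty list via default_factory (a mutable default needs Field in pydantic). The vLLM-specific extras (best_of, top_k, ignore_eos, use_beam_search) ride along after the standard OpenAI fields. An illustrative check of the defaults (placeholder model name):

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    req = ChatCompletionRequest(
        model="vicuna-7b-v1.1",  # placeholder
        messages=[{"role": "user", "content": "Hi!"}],
    )
    print(req.max_tokens, req.top_k, req.stop)  # -> 16 -1 []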
@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
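The response-side models mirror OpenAI's wire format, with object discriminators "chat.completion" and "chat.completion.chunk" and ids prefixed chatcmpl-. Serializing a chunk shows the shape the streaming generator emits (placeholder values, pydantic v1-style .json() as used by the endpoint itself):

    from vllm.entrypoints.openai.protocol import (
        ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
        DeltaMessage)

    chunk = ChatCompletionStreamResponse(
        model="vicuna-7b-v1.1",  # placeholder
        choices=[ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content="Hel"))],
    )
    print(chunk.json(ensure_ascii=False))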