[Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355)

Cyrus Leung 2024-04-27 13:08:24 +08:00 committed by GitHub
parent 12628d3c78
commit 8947bc3c15
8 changed files with 113 additions and 55 deletions

View File

@@ -8,6 +8,7 @@ py-cpuinfo
 transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
+openai
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
 prometheus_client >= 0.18.0

View File

@@ -21,7 +21,6 @@ pytest-rerunfailures
 pytest-shard
 httpx
 einops # required for MPT
-openai
 requests
 ray
 peft

View File

@@ -15,6 +15,7 @@ import ray
 import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
+from openai import BadRequestError

 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     assert loaded == {"result": 2}, loaded


+async def test_extra_fields(server, client: openai.AsyncOpenAI):
+    with pytest.raises(BadRequestError) as exc_info:
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant.",
+                "extra_field": "0",
+            }],  # type: ignore
+            temperature=0,
+            seed=0)
+
+    assert "extra_forbidden" in exc_info.value.message
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement

View File

@@ -9,7 +9,7 @@ import json
 import ssl

 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.serving_engine import LoRA
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath


 class LoRAParserAction(argparse.Action):
@@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
         lora_list = []
         for item in values:
             name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
+            lora_list.append(LoRAModulePath(name, path))
         setattr(namespace, self.dest, lora_list)
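
Note: for context, a minimal standalone sketch (not part of this commit) of how a `--lora-modules name=path` argument ends up as LoRAModulePath entries. The adapter name, path, and the nargs setting below are assumptions used only for illustration.

# Standalone sketch of the --lora-modules parsing shown above; "sql-lora" and
# its path are placeholders, and nargs="+" is an assumed registration.
import argparse
from dataclasses import dataclass


@dataclass
class LoRAModulePath:
    name: str
    local_path: str


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            # each value has the form "<adapter name>=<local path>"
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/adapter"])
# args.lora_modules == [LoRAModulePath(name='sql-lora', local_path='/path/to/adapter')]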

View File

@@ -4,14 +4,20 @@ import time
 from typing import Dict, List, Literal, Optional, Union

 import torch
-from pydantic import BaseModel, Field, model_validator
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated

 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid


-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
     object: str = "error"
     message: str
     type: str
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
     code: int


-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
     is_blocking: bool = False


-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -44,26 +50,26 @@
     permission: List[ModelPermission] = Field(default_factory=list)


-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)


-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0


-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
     type: Literal["text", "json_object"]


-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
-    messages: List[Dict[str, str]]
+    messages: List[ChatCompletionMessageParam]
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
         return data


-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
     model: str
@@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
         return data


-class LogProbs(BaseModel):
+class LogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
     top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
     )


-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo


-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
     )


-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


-class ChatMessage(BaseModel):
+class ChatMessage(OpenAIBaseModel):
     role: str
     content: str


-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
     usage: UsageInfo


-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None


-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
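
Note: a minimal standalone sketch of the effect of the new shared base class. DemoRequest is illustrative only, not one of the models in this commit; with extra="forbid", pydantic rejects unknown fields and reports them with the "extra_forbidden" error type that the new test asserts on.

# Standalone sketch: pydantic v2 rejects unknown fields when extra="forbid".
# DemoRequest is a hypothetical model used only for illustration.
from pydantic import BaseModel, ConfigDict, ValidationError


class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")


class DemoRequest(OpenAIBaseModel):
    model: str


try:
    DemoRequest(model="my-model", extra_field="0")
except ValidationError as e:
    print(e.errors()[0]["type"])  # prints "extra_forbidden"

The server turns this validation failure into an HTTP 400 error response, which is why the client-side test sees a BadRequestError whose message contains "extra_forbidden".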

View File

@@ -1,8 +1,11 @@
 import codecs
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
+                    Optional, Tuple, TypedDict, Union, final)

 from fastapi import Request
+from openai.types.chat import (ChatCompletionContentPartParam,
+                               ChatCompletionRole)

 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
@@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                     OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -20,20 +24,41 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)


+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
 class OpenAIServingChat(OpenAIServing):

     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
                  response_role: str,
-                 lora_modules: Optional[List[LoRA]] = None,
-                 chat_template=None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)

+    def _parse_chat_message_content(
+        self,
+        role: ChatCompletionRole,
+        content: Optional[Union[str,
+                                Iterable[ChatCompletionContentPartParam]]],
+    ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
+        if content is None:
+            return [], []
+        if isinstance(content, str):
+            return [ConversationMessage(role=role, content=content)], []
+
+        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
+        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
+        raise NotImplementedError("Complex input not supported yet")
+
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
@@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret

         try:
+            conversation: List[ConversationMessage] = []
+
+            for m in request.messages:
+                messages, _ = self._parse_chat_message_content(
+                    m["role"], m["content"])
+
+                conversation.extend(messages)
+
             prompt = self.tokenizer.apply_chat_template(
-                conversation=request.messages,
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt)
+                add_generation_prompt=request.add_generation_prompt,
+            )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
@@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
+            result_generator: AsyncIterator[RequestOutput],
+            request_id: str) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):
         model_name = self.served_model_names[0]
         created_time = int(time.time())
-        final_res: RequestOutput = None
+        final_res: Optional[RequestOutput] = None

         async for res in result_generator:
             if await raw_request.is_disconnected():
@@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):
         return response

-    def _load_chat_template(self, chat_template):
+    def _load_chat_template(self, chat_template: Optional[str]):
         tokenizer = self.tokenizer

         if chat_template is not None:
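
Note: a minimal standalone sketch of the string-content path that the new _parse_chat_message_content handles; the example messages are placeholders and the helper below is a simplified stand-in, not the method from this commit.

# Standalone sketch: each OpenAI-style message with string content is
# flattened into a ConversationMessage dict before the tokenizer's chat
# template is applied. The example messages are placeholders.
from typing import Iterable, List, Optional, Tuple, TypedDict, Union


class ConversationMessage(TypedDict):
    role: str
    content: str


def parse_chat_message_content(
    role: str,
    content: Optional[Union[str, Iterable[dict]]],
) -> Tuple[List[ConversationMessage], list]:
    if content is None:
        return [], []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)], []
    # Multi-part (e.g. image) content is likewise unsupported in this sketch.
    raise NotImplementedError("Complex input not supported yet")


conversation: List[ConversationMessage] = []
for m in [{"role": "system", "content": "You are a helpful assistant."},
          {"role": "user", "content": "Hello!"}]:
    messages, _ = parse_chat_message_content(m["role"], m["content"])
    conversation.extend(messages)
# conversation now holds plain role/content dicts, ready for
# tokenizer.apply_chat_template(conversation=conversation, ...)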

View File

@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               LogProbs, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                     OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRA]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                                     num_prompts=len(prompts))

         # Non-streaming response
-        final_res_batch: RequestOutput = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
                 if await raw_request.is_disconnected():

View File

@@ -22,17 +22,15 @@ logger = init_logger(__name__)
 @dataclass
-class LoRA:
+class LoRAModulePath:
     name: str
     local_path: str


 class OpenAIServing:

-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 served_model_names: List[str],
-                 lora_modules=Optional[List[LoRA]]):
+    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]]):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -158,7 +156,9 @@
         })
         return json_str

-    async def _check_model(self, request) -> Optional[ErrorResponse]:
+    async def _check_model(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
@@ -168,14 +168,16 @@
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)

-    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+    def _maybe_get_lora(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
             return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
         # if _check_model has been called earlier, this will be unreachable
-        raise ValueError("The model `{request.model}` does not exist.")
+        raise ValueError(f"The model `{request.model}` does not exist.")

     def _validate_prompt_and_tokenize(
             self,
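
Note: for reference, a minimal standalone sketch of the lookup that the now-typed _maybe_get_lora performs; the adapter class, model names, and paths below are placeholders, not objects from this commit.

# Standalone sketch: a request naming a served base model needs no adapter,
# a request naming a registered LoRA module resolves to that adapter, and
# anything else is an error. All names here are placeholders.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoRAAdapter:
    lora_name: str
    local_path: str


def maybe_get_lora(model: str, served_model_names: List[str],
                   lora_adapters: List[LoRAAdapter]) -> Optional[LoRAAdapter]:
    if model in served_model_names:
        return None  # base model; no adapter needed
    for lora in lora_adapters:
        if model == lora.lora_name:
            return lora
    # unreachable if the model name was validated by _check_model first
    raise ValueError(f"The model `{model}` does not exist.")


adapters = [LoRAAdapter("sql-lora", "/path/to/adapter")]
assert maybe_get_lora("base-model", ["base-model"], adapters) is None
assert maybe_get_lora("sql-lora", ["base-model"], adapters) is adapters[0]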