[Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355)

Cyrus Leung 2024-04-27 13:08:24 +08:00 committed by GitHub
parent 12628d3c78
commit 8947bc3c15
8 changed files with 113 additions and 55 deletions
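The crux of the change: every request and response model in the OpenAI-compatible frontend now derives from a shared OpenAIBaseModel whose pydantic config sets extra="forbid", so a request carrying unrecognized fields fails validation instead of being silently accepted. A minimal standalone sketch of that mechanism (illustrative class names, not the exact vLLM models):

from pydantic import BaseModel, ConfigDict, ValidationError

class OpenAIBaseModel(BaseModel):
    # Reject any field that is not declared on the model.
    model_config = ConfigDict(extra="forbid")

class ExampleChatRequest(OpenAIBaseModel):
    model: str
    temperature: float = 1.0

try:
    ExampleChatRequest(model="dummy", temperature=0.0, extra_field="0")
except ValidationError as e:
    # pydantic reports each offending key with error type "extra_forbidden",
    # which is exactly what the new test asserts on.
    print(e.errors()[0]["type"])  # extra_forbidden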

View File

@ -8,6 +8,7 @@ py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0

View File

@ -21,7 +21,6 @@ pytest-rerunfailures
pytest-shard
httpx
einops # required for MPT
openai
requests
ray
peft
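Note on the two requirements changes above: openai moves from the development requirements into the runtime requirements because the server code itself now imports the package's type definitions, e.g.:

# These imports now run in the server process, not only in the test suite.
from openai.types.chat import (ChatCompletionContentPartParam,
                               ChatCompletionMessageParam)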

View File

@ -15,6 +15,7 @@ import ray
import requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded
async def test_extra_fields(server, client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "system",
"content": "You are a helpful assistant.",
"extra_field": "0",
}], # type: ignore
temperature=0,
seed=0)
assert "extra_forbidden" in exc_info.value.message
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
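The new test_extra_fields test above exercises the behavior through the openai client; the same rejection is visible to any raw HTTP client. A hedged sketch of what that looks like against a locally running server (the URL, port, and model name below are assumptions, not part of this diff):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",                                  # assumed served model name
        "messages": [{"role": "user", "content": "Hello"}],
        "extra_field": "0",                                   # previously ignored, now rejected
    },
)
print(resp.status_code)  # expected: 400 (Bad Request)
print(resp.text)         # expected: error body mentioning "extra_forbidden"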

View File

@ -9,7 +9,7 @@ import json
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.serving_engine import LoRA
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
class LoRAParserAction(argparse.Action):
@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
lora_list.append(LoRAModulePath(name, path))
setattr(namespace, self.dest, lora_list)
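For context, LoRAParserAction converts each name=path token given to --lora-modules into one of the renamed LoRAModulePath entries. A self-contained sketch of the same parsing pattern (a standalone reimplementation for illustration, mirroring the code above):

import argparse
from dataclasses import dataclass

@dataclass
class LoRAModulePath:
    name: str
    local_path: str

class LoRAParserAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # Each value looks like "adapter-name=/path/to/adapter".
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
        setattr(namespace, self.dest, lora_list)

parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/adapter"])
print(args.lora_modules)
# [LoRAModulePath(name='sql-lora', local_path='/path/to/adapter')]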

View File

@ -4,14 +4,20 @@ import time
from typing import Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, Field, model_validator
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
class ErrorResponse(BaseModel):
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
code: int
class ModelPermission(BaseModel):
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
is_blocking: bool = False
class ModelCard(BaseModel):
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
@ -44,26 +50,26 @@ class ModelCard(BaseModel):
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ResponseFormat(BaseModel):
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel):
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[Dict[str, str]]
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
return data
class CompletionRequest(BaseModel):
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
return data
class LogProbs(BaseModel):
class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(BaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
)
class CompletionResponse(BaseModel):
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
)
class CompletionStreamResponse(BaseModel):
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
usage: UsageInfo
class DeltaMessage(BaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(BaseModel):
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
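Besides the extra-field restriction, messages is now typed as List[ChatCompletionMessageParam] (the TypedDicts shipped with the openai package) instead of List[Dict[str, str]], so the schema can describe both plain-string and structured content. A small illustration of the two shapes the type admits (the structured form is still rejected at runtime by the chat frontend, see serving_chat.py below):

from typing import List
from openai.types.chat import ChatCompletionMessageParam

messages: List[ChatCompletionMessageParam] = [
    # Plain string content, as before.
    {"role": "system", "content": "You are a helpful assistant."},
    # Structured content parts are now valid at the schema level,
    # though the server currently raises NotImplementedError for them.
    {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
]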

View File

@ -1,8 +1,11 @@
import codecs
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@ -20,20 +24,41 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@final  # So that it can be treated as a Dict[str, str] by type checkers
class ConversationMessage(TypedDict):
role: str
content: str
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
def _parse_chat_message_content(
self,
role: ChatCompletionRole,
content: Optional[Union[str,
Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
if content is None:
return [], []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []
# To be implemented: https://github.com/vllm-project/vllm/pull/3467
# To be implemented: https://github.com/vllm-project/vllm/pull/4200
raise NotImplementedError("Complex input not supported yet")
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret
try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])
conversation.extend(messages)
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
add_generation_prompt=request.add_generation_prompt,
)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: RequestOutput = None
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):
return response
def _load_chat_template(self, chat_template):
def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer
if chat_template is not None:
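The new _parse_chat_message_content helper normalizes each incoming message before the list is handed to tokenizer.apply_chat_template: None becomes nothing, a plain string becomes one ConversationMessage, and structured content parts are rejected until the linked PRs land. A rough standalone sketch of that flow (the free function below is a hypothetical stand-in for the method, and it drops the Awaitable list the real method also returns):

from typing import List, TypedDict

class ConversationMessage(TypedDict):
    role: str
    content: str

def parse_chat_message_content(role: str, content) -> List[ConversationMessage]:
    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]
    # Structured content parts (images, multiple text parts, ...) are
    # not supported yet; mirror the server's behavior.
    raise NotImplementedError("Complex input not supported yet")

conversation: List[ConversationMessage] = []
for m in [{"role": "system", "content": "Be brief."},
          {"role": "user", "content": "Hello"}]:
    conversation.extend(parse_chat_message_content(m["role"], m["content"]))
print(conversation)
# [{'role': 'system', 'content': 'Be brief.'}, {'role': 'user', 'content': 'Hello'}]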

View File

@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRA]] = None):
lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.time())
# Schedule the request and get the result generator.
generators = []
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts=len(prompts))
# Non-streaming response
final_res_batch: RequestOutput = [None] * len(prompts)
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():

View File

@ -22,17 +22,15 @@ logger = init_logger(__name__)
@dataclass
class LoRA:
class LoRAModulePath:
name: str
local_path: str
class OpenAIServing:
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules=Optional[List[LoRA]]):
def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]):
self.engine = engine
self.served_model_names = served_model_names
if lora_modules is None:
@ -158,7 +156,9 @@ class OpenAIServing:
})
return json_str
async def _check_model(self, request) -> Optional[ErrorResponse]:
async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
@ -168,14 +168,16 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
raise ValueError("The model `{request.model}` does not exist.")
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
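Two genuine fixes hide among the annotation changes in this last file: the constructor previously wrote lora_modules=Optional[List[LoRA]] (assigning the type as a default value instead of annotating the parameter), and the error string in _maybe_get_lora was missing its f prefix, so clients would have seen the literal text {request.model} rather than the model name. A two-line illustration of the latter:

model = "my-lora"
print("The model `{model}` does not exist.")   # prints: The model `{model}` does not exist.
print(f"The model `{model}` does not exist.")  # prints: The model `my-lora` does not exist.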