[Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355)

Parent: 12628d3c78
Commit: 8947bc3c15
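
Note: the user-visible effect of this change is that a chat request carrying a field the OpenAI schema does not define is now rejected with HTTP 400 instead of being silently ignored. A minimal client-side sketch of the new behavior; the server URL, API key, and model name are placeholders, not part of this commit:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.chat.completions.create(
        model="my-model",  # placeholder served model name
        messages=[{
            "role": "user",
            "content": "Hi",
            "extra_field": "0",  # not part of the OpenAI chat schema
        }],  # type: ignore
    )
except openai.BadRequestError as e:
    # The 400 body produced by the server mentions pydantic's
    # "extra_forbidden" error type, which the new test asserts on.
    print(e)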
@@ -8,6 +8,7 @@ py-cpuinfo
 transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
+openai
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
 prometheus_client >= 0.18.0
@@ -21,7 +21,6 @@ pytest-rerunfailures
 pytest-shard
 httpx
 einops # required for MPT
-openai
 requests
 ray
 peft
@@ -15,6 +15,7 @@ import ray
 import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
+from openai import BadRequestError

 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     assert loaded == {"result": 2}, loaded


+async def test_extra_fields(server, client: openai.AsyncOpenAI):
+    with pytest.raises(BadRequestError) as exc_info:
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant.",
+                "extra_field": "0",
+            }], # type: ignore
+            temperature=0,
+            seed=0)
+
+    assert "extra_forbidden" in exc_info.value.message
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement
@@ -9,7 +9,7 @@ import json
 import ssl

 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.serving_engine import LoRA
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath


 class LoRAParserAction(argparse.Action):
@@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
         lora_list = []
         for item in values:
             name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
+            lora_list.append(LoRAModulePath(name, path))
         setattr(namespace, self.dest, lora_list)
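
Note: the rename from LoRA to LoRAModulePath above only touches CLI plumbing. A standalone sketch of how the parser action turns name=path arguments into LoRAModulePath entries; the classes are re-implemented here for illustration (the real ones live under vllm.entrypoints.openai), and the --lora-modules flag usage is an assumption based on this action:

import argparse
from dataclasses import dataclass


@dataclass
class LoRAModulePath:
    name: str
    local_path: str


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        # Each value looks like "adapter-name=/path/to/adapter".
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/adapter"])
print(args.lora_modules)  # [LoRAModulePath(name='sql-lora', local_path='/path/to/adapter')]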
@@ -4,14 +4,20 @@ import time
 from typing import Dict, List, Literal, Optional, Union

 import torch
-from pydantic import BaseModel, Field, model_validator
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated

 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid


-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
     object: str = "error"
     message: str
     type: str
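
Note: the new OpenAIBaseModel relies on pydantic v2's extra="forbid" model config, which turns any unknown field into a validation error of type "extra_forbidden". A self-contained sketch of that behavior; the Probe model is hypothetical, only the config line mirrors the diff:

from pydantic import BaseModel, ConfigDict, ValidationError


class Probe(BaseModel):
    # Same config the diff adds to OpenAIBaseModel: unknown fields are rejected.
    model_config = ConfigDict(extra="forbid")
    role: str


try:
    Probe(role="system", extra_field="0")  # unexpected field triggers validation
except ValidationError as e:
    # pydantic v2 reports the error type as "extra_forbidden",
    # which is the string the new server test asserts on.
    print(e.errors()[0]["type"])  # extra_forbidden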
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
     code: int


-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
     is_blocking: bool = False


-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
     permission: List[ModelPermission] = Field(default_factory=list)


-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)


-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0


-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
     type: Literal["text", "json_object"]


-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
-    messages: List[Dict[str, str]]
+    messages: List[ChatCompletionMessageParam]
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
         return data


-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
     model: str
@@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
         return data


-class LogProbs(BaseModel):
+class LogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
     top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
     )


-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo


-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
     )


-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


-class ChatMessage(BaseModel):
+class ChatMessage(OpenAIBaseModel):
     role: str
     content: str


-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
     usage: UsageInfo


-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None


-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -1,8 +1,11 @@
 import codecs
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
+                    Optional, Tuple, TypedDict, Union, final)

 from fastapi import Request
+from openai.types.chat import (ChatCompletionContentPartParam,
+                               ChatCompletionRole)

 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
@@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -20,20 +24,41 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)


+@final # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
 class OpenAIServingChat(OpenAIServing):

     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
                  response_role: str,
-                 lora_modules: Optional[List[LoRA]] = None,
-                 chat_template=None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)

+    def _parse_chat_message_content(
+        self,
+        role: ChatCompletionRole,
+        content: Optional[Union[str,
+                                Iterable[ChatCompletionContentPartParam]]],
+    ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
+        if content is None:
+            return [], []
+        if isinstance(content, str):
+            return [ConversationMessage(role=role, content=content)], []
+
+        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
+        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
+        raise NotImplementedError("Complex input not supported yet")
+
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
@@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret

         try:
+            conversation: List[ConversationMessage] = []
+
+            for m in request.messages:
+                messages, _ = self._parse_chat_message_content(
+                    m["role"], m["content"])
+
+                conversation.extend(messages)
+
             prompt = self.tokenizer.apply_chat_template(
-                conversation=request.messages,
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt)
+                add_generation_prompt=request.add_generation_prompt,
+            )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
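
Note: the handler now builds a plain list of ConversationMessage dicts before rendering the prompt, since request.messages may carry typed content parts that a Hugging Face chat template cannot consume directly. A rough sketch of the template step in isolation; the model name is an arbitrary public example, not something this diff depends on:

from transformers import AutoTokenizer

# Any tokenizer that ships a chat_template behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    conversation=conversation,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)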
@@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):

     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-
+            result_generator: AsyncIterator[RequestOutput],
+            request_id: str) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):

         model_name = self.served_model_names[0]
         created_time = int(time.time())
-        final_res: RequestOutput = None
+        final_res: Optional[RequestOutput] = None

         async for res in result_generator:
             if await raw_request.is_disconnected():
@@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):

         return response

-    def _load_chat_template(self, chat_template):
+    def _load_chat_template(self, chat_template: Optional[str]):
         tokenizer = self.tokenizer

         if chat_template is not None:
@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               LogProbs, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRA]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                                     num_prompts=len(prompts))

         # Non-streaming response
-        final_res_batch: RequestOutput = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
                 if await raw_request.is_disconnected():
@@ -22,17 +22,15 @@ logger = init_logger(__name__)


 @dataclass
-class LoRA:
+class LoRAModulePath:
     name: str
     local_path: str


 class OpenAIServing:

-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 served_model_names: List[str],
-                 lora_modules=Optional[List[LoRA]]):
+    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]]):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -158,7 +156,9 @@ class OpenAIServing:
         })
         return json_str

-    async def _check_model(self, request) -> Optional[ErrorResponse]:
+    async def _check_model(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
@@ -168,14 +168,16 @@ class OpenAIServing:
                                          err_type="NotFoundError",
                                          status_code=HTTPStatus.NOT_FOUND)

-    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+    def _maybe_get_lora(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
             return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
-        raise ValueError("The model `{request.model}` does not exist.")
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError(f"The model `{request.model}` does not exist.")

     def _validate_prompt_and_tokenize(
         self,