[Bugfix]: allow extra fields in requests to openai compatible server (#10463)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
This commit is contained in:
Guillaume Calmettes 2024-11-20 22:42:21 +01:00 committed by GitHub
parent 0cd3d9717e
commit c68f7ede6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 14 deletions

View File

@@ -899,19 +899,19 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_extra_fields(client: openai.AsyncOpenAI):
-    with pytest.raises(BadRequestError) as exc_info:
-        await client.chat.completions.create(
-            model=MODEL_NAME,
-            messages=[{
-                "role": "system",
-                "content": "You are a helpful assistant.",
-                "extra_field": "0",
-            }],  # type: ignore
-            temperature=0,
-            seed=0)
-
-    assert "extra_forbidden" in exc_info.value.message
+async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?",
+            "extra_field": "0",
+        }],  # type: ignore
+        temperature=0,
+        seed=0)
+
+    content = resp.choices[0].message.content
+    assert content is not None
 
 @pytest.mark.asyncio

View File

@@ -9,12 +9,15 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 # torch is mocked during docs generation,
 # so we have to provide the values as literals
 _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
@@ -35,8 +38,19 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max
 
 class OpenAIBaseModel(BaseModel):
-    # OpenAI API does not allow extra fields
-    model_config = ConfigDict(extra="forbid")
+    # OpenAI API does allow extra fields
+    model_config = ConfigDict(extra="allow")
+
+    @model_validator(mode="before")
+    @classmethod
+    def __log_extra_fields__(cls, data):
+        if isinstance(data, dict):
+            extra_fields = data.keys() - cls.model_fields.keys()
+            if extra_fields:
+                logger.warning(
+                    "The following fields were present in the request "
+                    "but ignored: %s", extra_fields)
+        return data
 
 
 class ErrorResponse(OpenAIBaseModel):