Fix the pydantic logging validator (#12420)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser 2025-01-29 04:53:13 -03:00 committed by GitHub
parent 5f671cb4c3
commit ef001d98ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,7 +6,8 @@ from argparse import Namespace
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel):
# Cache class field names # Cache class field names
field_names: ClassVar[Optional[Set[str]]] = None field_names: ClassVar[Optional[Set[str]]] = None
@model_validator(mode="before") @model_validator(mode="wrap")
@classmethod @classmethod
def __log_extra_fields__(cls, data): def __log_extra_fields__(cls, data, handler):
result = handler(data)
if not isinstance(data, dict):
return result
field_names = cls.field_names field_names = cls.field_names
if field_names is None: if field_names is None:
if not isinstance(data, dict):
return data
# Get all class field names and their potential aliases # Get all class field names and their potential aliases
field_names = set() field_names = set()
for field_name, field in cls.model_fields.items(): for field_name, field in cls.model_fields.items():
@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel):
"The following fields were present in the request " "The following fields were present in the request "
"but ignored: %s", "but ignored: %s",
data.keys() - field_names) data.keys() - field_names)
return data return result
class ErrorResponse(OpenAIBaseModel): class ErrorResponse(OpenAIBaseModel):
@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel):
# The parameters of the request. # The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
class BatchResponseData(OpenAIBaseModel): class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response. # HTTP status code of the response.