Fix the pydantic logging validator (#12420)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser 2025-01-29 04:53:13 -03:00 committed by GitHub
parent 5f671cb4c3
commit ef001d98ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,7 +6,8 @@ from argparse import Namespace
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel):
# Cache class field names
field_names: ClassVar[Optional[Set[str]]] = None
@model_validator(mode="before")
@model_validator(mode="wrap")
@classmethod
def __log_extra_fields__(cls, data):
def __log_extra_fields__(cls, data, handler):
result = handler(data)
if not isinstance(data, dict):
return result
field_names = cls.field_names
if field_names is None:
if not isinstance(data, dict):
return data
# Get all class field names and their potential aliases
field_names = set()
for field_name, field in cls.model_fields.items():
@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel):
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
return data
return result
class ErrorResponse(OpenAIBaseModel):
@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel):
# The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.