mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:26:15 +08:00
Fix the pydantic logging validator (#12420)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
parent
5f671cb4c3
commit
ef001d98ef
@ -6,7 +6,8 @@ from argparse import Namespace
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
|
||||
ValidationInfo, field_validator, model_validator)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel):
|
||||
# Cache class field names
|
||||
field_names: ClassVar[Optional[Set[str]]] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def __log_extra_fields__(cls, data):
|
||||
|
||||
def __log_extra_fields__(cls, data, handler):
|
||||
result = handler(data)
|
||||
if not isinstance(data, dict):
|
||||
return result
|
||||
field_names = cls.field_names
|
||||
if field_names is None:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
# Get all class field names and their potential aliases
|
||||
field_names = set()
|
||||
for field_name, field in cls.model_fields.items():
|
||||
@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel):
|
||||
"The following fields were present in the request "
|
||||
"but ignored: %s",
|
||||
data.keys() - field_names)
|
||||
return data
|
||||
return result
|
||||
|
||||
|
||||
class ErrorResponse(OpenAIBaseModel):
|
||||
@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel):
|
||||
# The parameters of the request.
|
||||
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
|
||||
|
||||
@field_validator('body', mode='plain')
|
||||
@classmethod
|
||||
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||
# Use url to disambiguate models
|
||||
url = info.data['url']
|
||||
if url == "/v1/chat/completions":
|
||||
return ChatCompletionRequest.model_validate(value)
|
||||
if url == "/v1/embeddings":
|
||||
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||
if url == "/v1/score":
|
||||
return ScoreRequest.model_validate(value)
|
||||
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
|
||||
ScoreRequest]).validate_python(value)
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user