mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 17:25:29 +08:00
Fix the pydantic logging validator (#12420)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
parent
5f671cb4c3
commit
ef001d98ef
@ -6,7 +6,8 @@ from argparse import Namespace
|
|||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
|
||||||
|
ValidationInfo, field_validator, model_validator)
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel):
|
|||||||
# Cache class field names
|
# Cache class field names
|
||||||
field_names: ClassVar[Optional[Set[str]]] = None
|
field_names: ClassVar[Optional[Set[str]]] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="wrap")
|
||||||
@classmethod
|
@classmethod
|
||||||
def __log_extra_fields__(cls, data):
|
def __log_extra_fields__(cls, data, handler):
|
||||||
|
result = handler(data)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return result
|
||||||
field_names = cls.field_names
|
field_names = cls.field_names
|
||||||
if field_names is None:
|
if field_names is None:
|
||||||
if not isinstance(data, dict):
|
|
||||||
return data
|
|
||||||
# Get all class field names and their potential aliases
|
# Get all class field names and their potential aliases
|
||||||
field_names = set()
|
field_names = set()
|
||||||
for field_name, field in cls.model_fields.items():
|
for field_name, field in cls.model_fields.items():
|
||||||
@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel):
|
|||||||
"The following fields were present in the request "
|
"The following fields were present in the request "
|
||||||
"but ignored: %s",
|
"but ignored: %s",
|
||||||
data.keys() - field_names)
|
data.keys() - field_names)
|
||||||
return data
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(OpenAIBaseModel):
|
class ErrorResponse(OpenAIBaseModel):
|
||||||
@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel):
|
|||||||
# The parameters of the request.
|
# The parameters of the request.
|
||||||
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
|
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
|
||||||
|
|
||||||
|
@field_validator('body', mode='plain')
|
||||||
|
@classmethod
|
||||||
|
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||||
|
# Use url to disambiguate models
|
||||||
|
url = info.data['url']
|
||||||
|
if url == "/v1/chat/completions":
|
||||||
|
return ChatCompletionRequest.model_validate(value)
|
||||||
|
if url == "/v1/embeddings":
|
||||||
|
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||||
|
if url == "/v1/score":
|
||||||
|
return ScoreRequest.model_validate(value)
|
||||||
|
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
|
||||||
|
ScoreRequest]).validate_python(value)
|
||||||
|
|
||||||
|
|
||||||
class BatchResponseData(OpenAIBaseModel):
|
class BatchResponseData(OpenAIBaseModel):
|
||||||
# HTTP status code of the response.
|
# HTTP status code of the response.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user