diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2bc136cc48038..29d071ce50c8e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,8 @@ from argparse import Namespace from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, + ValidationInfo, field_validator, model_validator) from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel): # Cache class field names field_names: ClassVar[Optional[Set[str]]] = None - @model_validator(mode="before") + @model_validator(mode="wrap") @classmethod - def __log_extra_fields__(cls, data): - + def __log_extra_fields__(cls, data, handler): + result = handler(data) + if not isinstance(data, dict): + return result field_names = cls.field_names if field_names is None: - if not isinstance(data, dict): - return data # Get all class field names and their potential aliases field_names = set() for field_name, field in cls.model_fields.items(): @@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel): "The following fields were present in the request " "but ignored: %s", data.keys() - field_names) - return data + return result class ErrorResponse(OpenAIBaseModel): @@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + @field_validator('body', mode='plain') + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url = info.data['url'] + if url == "/v1/chat/completions": + return ChatCompletionRequest.model_validate(value) + if url == "/v1/embeddings": + return TypeAdapter(EmbeddingRequest).validate_python(value) + if url == "/v1/score": + return ScoreRequest.model_validate(value) + return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest]).validate_python(value) + class BatchResponseData(OpenAIBaseModel): # HTTP status code of the response.