[Bugfix]: make most of test_openai_schema.py pass (#17664)

David Xia 2025-05-14 20:04:35 -04:00 committed by GitHub
parent 09f106a91e
commit f25e0d1125
3 changed files with 240 additions and 35 deletions

vllm/entrypoints/openai/api_server.py

@@ -17,8 +17,10 @@ from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
+from json import JSONDecodeError
from typing import Annotated, Optional, Union
+import prometheus_client
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
@@ -305,15 +307,18 @@ async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
-raise HTTPException(
-status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-detail="Unsupported Media Type: Only 'application/json' is allowed"
-)
+raise RequestValidationError(errors=[
+"Unsupported Media Type: Only 'application/json' is allowed"
+])
router = APIRouter()
class PrometheusResponse(Response):
media_type = prometheus_client.CONTENT_TYPE_LATEST
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
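Note: with validate_json_request now raising RequestValidationError instead of a bare HTTPException, a wrong Content-Type surfaces through the 400 handler added at the end of this file. A minimal client-side sketch of the assumed effect (the URL, port, and use of the requests package are illustrative, not part of this diff):

    import requests

    resp = requests.post("http://localhost:8000/tokenize",
                         data="prompt=hi",
                         headers={"Content-Type": "text/plain"})
    assert resp.status_code == 400  # was a bare 415 before this change
    assert "Unsupported Media Type" in resp.json()["message"]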
@@ -332,6 +337,10 @@ def mount_metrics(app: FastAPI):
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
+# `response_class=PrometheusResponse` is needed to return an HTTP response
+# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
+# instead of the default "application/json" which is incorrect.
+# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator(
excluded_handlers=[
"/metrics",
@@ -342,7 +351,7 @@ def mount_metrics(app: FastAPI):
"/server_info",
],
registry=registry,
-).add().instrument(app).expose(app)
+).add().instrument(app).expose(app, response_class=PrometheusResponse)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
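Note: prometheus_client.CONTENT_TYPE_LATEST is the Prometheus exposition-format header that PrometheusResponse advertises as its media_type. A quick sketch to confirm the constant's value (assuming prometheus_client is installed):

    import prometheus_client

    # The constant used above as PrometheusResponse.media_type:
    print(prometheus_client.CONTENT_TYPE_LATEST)
    # text/plain; version=0.0.4; charset=utf-8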
@@ -401,11 +410,11 @@ def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
-@router.get("/health")
-async def health(raw_request: Request) -> JSONResponse:
+@router.get("/health", response_class=Response)
+async def health(raw_request: Request) -> Response:
"""Health check."""
await engine_client(raw_request).check_health()
-return JSONResponse(content={}, status_code=200)
+return Response(status_code=200)
@router.get("/load")
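Note: response_class=Response on /health matters beyond the empty body: without it, FastAPI documents the endpoint as returning application/json in the generated OpenAPI schema, which is what the schema tests flag. A hypothetical smoke test (server URL assumed):

    import requests

    resp = requests.get("http://localhost:8000/health")
    assert resp.status_code == 200
    assert resp.content == b""  # empty body instead of the old "{}"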
@@ -427,18 +436,42 @@ async def get_server_load_metrics(request: Request):
content={'server_load': request.app.state.server_load_metrics})
-@router.api_route("/ping", methods=["GET", "POST"])
-async def ping(raw_request: Request) -> JSONResponse:
+@router.get("/ping", response_class=Response)
+@router.post("/ping", response_class=Response)
+async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@router.post("/tokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.NOT_FOUND.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
HTTPStatus.NOT_IMPLEMENTED.value: {
"model": ErrorResponse
},
})
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
-generator = await handler.create_tokenize(request, raw_request)
+try:
+generator = await handler.create_tokenize(request, raw_request)
+except NotImplementedError as e:
+raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value,
+detail=str(e)) from e
+except Exception as e:
+raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+detail=str(e)) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
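Note: the responses= mapping is purely declarative. It adds the error statuses to the generated OpenAPI document and does not change runtime behavior. A self-contained sketch of the pattern (the ErrorResponse stand-in here is simplified, not vLLM's actual model):

    from http import HTTPStatus
    from fastapi import FastAPI
    from pydantic import BaseModel

    class ErrorResponse(BaseModel):  # simplified stand-in
        message: str
        type: str
        code: int

    app = FastAPI()

    @app.post("/tokenize",
              responses={HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}})
    async def tokenize():
        return {}

    # app.openapi()["paths"]["/tokenize"]["post"]["responses"] now lists "501"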
@@ -448,12 +481,31 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never(generator)
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@router.post("/detokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.NOT_FOUND.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
-generator = await handler.create_detokenize(request, raw_request)
+try:
+generator = await handler.create_detokenize(request, raw_request)
+except OverflowError as e:
+raise RequestValidationError(errors=[str(e)]) from e
+except Exception as e:
+raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+detail=str(e)) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
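Note: raising RequestValidationError from inside a handler body (rather than letting the error escape as a 500) works because Starlette routes it to the registered exception handler; FastAPI's default handler returns a 422, and the override at the end of this file turns it into a 400 ErrorResponse. A minimal sketch of the mechanism (endpoint name and trigger are hypothetical):

    from fastapi import FastAPI
    from fastapi.exceptions import RequestValidationError

    app = FastAPI()

    @app.post("/detok")
    async def detok():
        # e.g. an OverflowError from decoding a huge token id,
        # re-raised as a validation error -> handled, not a 500
        raise RequestValidationError(errors=["token id out of range"])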
@@ -478,7 +530,23 @@ async def show_version():
@router.post("/v1/chat/completions",
-dependencies=[Depends(validate_json_request)])
+dependencies=[Depends(validate_json_request)],
+responses={
+HTTPStatus.OK.value: {
+"content": {
+"text/event-stream": {}
+}
+},
+HTTPStatus.BAD_REQUEST.value: {
+"model": ErrorResponse
+},
+HTTPStatus.NOT_FOUND.value: {
+"model": ErrorResponse
+},
+HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+"model": ErrorResponse
+}
+})
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest,
@@ -500,7 +568,24 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
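Note: because streaming replies are produced via StreamingResponse at runtime, FastAPI cannot infer the text/event-stream content type for the schema; the explicit HTTPStatus.OK entry in the decorator above is what puts it in the OpenAPI document. A rough standalone sketch:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    @app.post("/v1/chat/completions",
              responses={200: {"content": {"text/event-stream": {}}}})
    async def chat():
        async def gen():
            yield "data: [DONE]\n\n"
        return StreamingResponse(gen(), media_type="text/event-stream")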
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@router.post("/v1/completions",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {
"content": {
"text/event-stream": {}
}
},
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.NOT_FOUND.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
@@ -509,7 +594,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return base(raw_request).create_error_response(
message="The model does not support Completions API")
-generator = await handler.create_completion(request, raw_request)
+try:
+generator = await handler.create_completion(request, raw_request)
+except OverflowError as e:
+raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+detail=str(e)) from e
+except Exception as e:
+raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+detail=str(e)) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
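Note: mapping OverflowError to 400 matters for schema testing because property-based inputs (an out-of-range seed is one plausible trigger, though the exact source is not shown in this diff) previously escaped as unhandled 500s. Client-side, the assumed effect:

    import requests  # hypothetical reproduction, values assumed

    resp = requests.post("http://localhost:8000/v1/completions",
                         json={"model": "m", "prompt": "hi", "seed": 2**63})
    assert resp.status_code == 400  # OverflowError -> HTTP 400, not 500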
@@ -519,7 +612,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@router.post("/v1/embeddings",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@@ -566,7 +668,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@router.post("/pooling",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
@@ -606,7 +717,16 @@ async def create_classify(request: ClassificationRequest,
assert_never(generator)
@router.post("/score", dependencies=[Depends(validate_json_request)])
@router.post("/score",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
@@ -625,7 +745,16 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@router.post("/v1/score",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
@@ -636,12 +765,28 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/v1/audio/transcriptions")
@router.post("/v1/audio/transcriptions",
responses={
HTTPStatus.OK.value: {
"content": {
"text/event-stream": {}
}
},
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
-async def create_transcriptions(request: Annotated[TranscriptionRequest,
-Form()],
-raw_request: Request):
+async def create_transcriptions(raw_request: Request,
+request: Annotated[TranscriptionRequest,
+Form()]):
handler = transcription(raw_request)
if handler is None:
return base(raw_request).create_error_response(
@@ -661,7 +806,16 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@router.post("/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
@@ -679,7 +833,16 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@router.post("/v1/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
@@ -690,7 +853,16 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@router.post("/v2/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
@@ -770,12 +942,29 @@ if envs.VLLM_SERVER_DEV_MODE:
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
@router.post("/invocations",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
-body = await raw_request.json()
+try:
+body = await raw_request.json()
+except JSONDecodeError as e:
+raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+detail=f"JSON decode error: {e}") from e
task = raw_request.app.state.task
if task not in TASK_HANDLERS:
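Note: Starlette's Request.json() raises json.JSONDecodeError on malformed bodies, so the try/except above turns a would-be 500 into a 400. A hypothetical reproduction (URL assumed):

    import requests

    resp = requests.post("http://localhost:8000/invocations",
                         data="{not json",
                         headers={"Content-Type": "application/json"})
    assert resp.status_code == 400
    assert "JSON decode error" in resp.json()["message"]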
@@ -866,10 +1055,26 @@ def build_app(args: Namespace) -> FastAPI:
allow_headers=args.allowed_headers,
)
+@app.exception_handler(HTTPException)
+async def http_exception_handler(_: Request, exc: HTTPException):
+err = ErrorResponse(message=exc.detail,
+type=HTTPStatus(exc.status_code).phrase,
+code=exc.status_code)
+return JSONResponse(err.model_dump(), status_code=exc.status_code)
@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-err = ErrorResponse(message=str(exc),
-type="BadRequestError",
+async def validation_exception_handler(_: Request,
+exc: RequestValidationError):
+exc_str = str(exc)
+errors_str = str(exc.errors())
+if exc.errors() and errors_str and errors_str != exc_str:
+message = f"{exc_str} {errors_str}"
+else:
+message = exc_str
+err = ErrorResponse(message=message,
+type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST)
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
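Note: with the new http_exception_handler, every HTTPException raised by the endpoints above is serialized through ErrorResponse, so error bodies are uniform across the API. Assuming vLLM's ErrorResponse fields, a 501 from /tokenize would look roughly like:

    # Approximate response body (field set assumed from the ErrorResponse model):
    # {
    #     "message": "<exception detail>",
    #     "type": "Not Implemented",  # HTTPStatus(501).phrase
    #     "code": 501
    # }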

vllm/entrypoints/openai/serving_chat.py

@@ -197,7 +197,7 @@ class OpenAIServingChat(OpenAIServing):
except (ValueError, TypeError, RuntimeError,
jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
-return self.create_error_response(str(e))
+return self.create_error_response(f"{e} {e.__cause__}")
request_id = "chatcmpl-" \
f"{self._base_request_id(raw_request, request.request_id)}"

vllm/entrypoints/openai/serving_tokenization.py

@@ -91,7 +91,7 @@ class OpenAIServingTokenization(OpenAIServing):
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
-return self.create_error_response(str(e))
+return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = []
for i, engine_prompt in enumerate(engine_prompts):