[Bugfix]: make most of test_openai_schema.py pass (#17664)

David Xia 2025-05-14 20:04:35 -04:00 committed by GitHub
parent 09f106a91e
commit f25e0d1125
3 changed files with 240 additions and 35 deletions


@@ -17,8 +17,10 @@ from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
+from json import JSONDecodeError
 from typing import Annotated, Optional, Union

+import prometheus_client
 import uvloop
 from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
@@ -305,15 +307,18 @@ async def validate_json_request(raw_request: Request):
     content_type = raw_request.headers.get("content-type", "").lower()
     media_type = content_type.split(";", maxsplit=1)[0]
     if media_type != "application/json":
-        raise HTTPException(
-            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-            detail="Unsupported Media Type: Only 'application/json' is allowed"
-        )
+        raise RequestValidationError(errors=[
+            "Unsupported Media Type: Only 'application/json' is allowed"
+        ])


 router = APIRouter()


+class PrometheusResponse(Response):
+    media_type = prometheus_client.CONTENT_TYPE_LATEST


 def mount_metrics(app: FastAPI):
     # Lazy import for prometheus multiprocessing.
     # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
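The old 415 HTTPException is replaced with a RequestValidationError so the rejection is reported through FastAPI's standard 400 validation path. A minimal, self-contained sketch of the same pattern (hypothetical app and route names, not vLLM's):

from fastapi import Depends, FastAPI, Request
from fastapi.exceptions import RequestValidationError

app = FastAPI()


async def require_json(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()
    media_type = content_type.split(";", maxsplit=1)[0]
    if media_type != "application/json":
        # Routed through FastAPI's validation handler -> HTTP 400.
        raise RequestValidationError(errors=[
            "Unsupported Media Type: Only 'application/json' is allowed"
        ])


@app.post("/echo", dependencies=[Depends(require_json)])
async def echo(raw_request: Request):
    return await raw_request.json()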
@@ -332,6 +337,10 @@ def mount_metrics(app: FastAPI):
         registry = CollectorRegistry()
         multiprocess.MultiProcessCollector(registry)

+    # `response_class=PrometheusResponse` is needed to return an HTTP response
+    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
+    # instead of the default "application/json" which is incorrect.
+    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
     Instrumentator(
         excluded_handlers=[
             "/metrics",
@@ -342,7 +351,7 @@ def mount_metrics(app: FastAPI):
             "/server_info",
         ],
         registry=registry,
-    ).add().instrument(app).expose(app)
+    ).add().instrument(app).expose(app, response_class=PrometheusResponse)

     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
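For context, the reason the custom response class is threaded through expose(): without it, FastAPI documents and serves the metrics endpoint as "application/json". A standalone sketch of the same idea (hypothetical app; generate_latest and CONTENT_TYPE_LATEST are real prometheus_client names):

import prometheus_client
from fastapi import FastAPI, Response

app = FastAPI()


class PrometheusResponse(Response):
    # "text/plain; version=0.0.4; charset=utf-8"
    media_type = prometheus_client.CONTENT_TYPE_LATEST


@app.get("/metrics", response_class=PrometheusResponse)
def metrics() -> Response:
    # Render the default registry in the Prometheus text exposition format.
    return PrometheusResponse(content=prometheus_client.generate_latest())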
@@ -401,11 +410,11 @@ def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client


-@router.get("/health")
-async def health(raw_request: Request) -> JSONResponse:
+@router.get("/health", response_class=Response)
+async def health(raw_request: Request) -> Response:
     """Health check."""
     await engine_client(raw_request).check_health()
-    return JSONResponse(content={}, status_code=200)
+    return Response(status_code=200)


 @router.get("/load")
@@ -427,18 +436,42 @@ async def get_server_load_metrics(request: Request):
         content={'server_load': request.app.state.server_load_metrics})


-@router.api_route("/ping", methods=["GET", "POST"])
-async def ping(raw_request: Request) -> JSONResponse:
+@router.get("/ping", response_class=Response)
+@router.post("/ping", response_class=Response)
+async def ping(raw_request: Request) -> Response:
     """Ping check. Endpoint required for SageMaker"""
     return await health(raw_request)


-@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
+@router.post("/tokenize",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.NOT_FOUND.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.NOT_IMPLEMENTED.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

-    generator = await handler.create_tokenize(request, raw_request)
+    try:
+        generator = await handler.create_tokenize(request, raw_request)
+    except NotImplementedError as e:
+        raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value,
+                            detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                            detail=str(e)) from e
+
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
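The `responses=` mapping is what feeds the OpenAPI document that test_openai_schema.py validates against. A toy illustration of how declared error statuses surface in the generated schema (ErrorResponse here is a stand-in Pydantic model, not vLLM's):

from http import HTTPStatus

from fastapi import FastAPI
from pydantic import BaseModel


class ErrorResponse(BaseModel):
    message: str
    type: str
    code: int


app = FastAPI()


@app.post("/tokenize",
          responses={
              HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
              HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
          })
async def tokenize(text: str):
    return {"tokens": text.split()}


# The declared error statuses now appear in the generated OpenAPI document.
responses = app.openapi()["paths"]["/tokenize"]["post"]["responses"]
assert "400" in responses and "501" in responses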
@@ -448,12 +481,31 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
+@router.post("/detokenize",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.NOT_FOUND.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

-    generator = await handler.create_detokenize(request, raw_request)
+    try:
+        generator = await handler.create_detokenize(request, raw_request)
+    except OverflowError as e:
+        raise RequestValidationError(errors=[str(e)]) from e
+    except Exception as e:
+        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                            detail=str(e)) from e
+
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
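The OverflowError branch is surfaced as a 400 through the validation handler rather than bubbling up as a 500. A sketch of the shape, with a hypothetical decode helper standing in for the tokenizer call:

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError

app = FastAPI()


def decode(token_ids: list[int]) -> str:
    # Hypothetical stand-in for a tokenizer that rejects out-of-range ids.
    if any(t > 2**31 - 1 for t in token_ids):
        raise OverflowError("token id out of range")
    return " ".join(map(str, token_ids))


@app.post("/detokenize")
async def detokenize(token_ids: list[int]):
    try:
        return {"prompt": decode(token_ids)}
    except OverflowError as e:
        # Reported as a 400 validation failure instead of a 500.
        raise RequestValidationError(errors=[str(e)]) from e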
@@ -478,7 +530,23 @@ async def show_version():


 @router.post("/v1/chat/completions",
-             dependencies=[Depends(validate_json_request)])
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.OK.value: {
+                     "content": {
+                         "text/event-stream": {}
+                     }
+                 },
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.NOT_FOUND.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 }
+             })
 @with_cancellation
 @load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
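Declaring "text/event-stream" content under the 200 status documents the streaming variant of the endpoint alongside the default JSON response. A toy sketch (the event payloads are made up):

import json
from http import HTTPStatus

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/v1/chat/completions",
          responses={
              HTTPStatus.OK.value: {
                  "content": {
                      "text/event-stream": {}
                  }
              },
          })
async def create_chat_completion(stream: bool = False):
    if not stream:
        return {"choices": []}

    async def events():
        # Made-up SSE payloads for illustration only.
        yield f"data: {json.dumps({'delta': 'hello'})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(events(), media_type="text/event-stream")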
@@ -500,7 +568,24 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+@router.post("/v1/completions",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.OK.value: {
+                     "content": {
+                         "text/event-stream": {}
+                     }
+                 },
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.NOT_FOUND.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
@@ -509,7 +594,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         return base(raw_request).create_error_response(
             message="The model does not support Completions API")

-    generator = await handler.create_completion(request, raw_request)
+    try:
+        generator = await handler.create_completion(request, raw_request)
+    except OverflowError as e:
+        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                            detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                            detail=str(e)) from e
+
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -519,7 +612,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
+@router.post("/v1/embeddings",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@@ -566,7 +668,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/pooling", dependencies=[Depends(validate_json_request)])
+@router.post("/pooling",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
@@ -606,7 +717,16 @@ async def create_classify(request: ClassificationRequest,
     assert_never(generator)


-@router.post("/score", dependencies=[Depends(validate_json_request)])
+@router.post("/score",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
@@ -625,7 +745,16 @@ async def create_score(request: ScoreRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
+@router.post("/v1/score",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
@@ -636,12 +765,28 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)


-@router.post("/v1/audio/transcriptions")
+@router.post("/v1/audio/transcriptions",
+             responses={
+                 HTTPStatus.OK.value: {
+                     "content": {
+                         "text/event-stream": {}
+                     }
+                 },
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.UNPROCESSABLE_ENTITY.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
-async def create_transcriptions(request: Annotated[TranscriptionRequest,
-                                                   Form()],
-                                raw_request: Request):
+async def create_transcriptions(raw_request: Request,
+                                request: Annotated[TranscriptionRequest,
+                                                   Form()]):
     handler = transcription(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
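The handler now takes raw_request first and binds the Pydantic request model from multipart form data via Annotated[..., Form()]. A sketch of that binding pattern (hypothetical fields; binding a Pydantic model from form fields needs FastAPI 0.113+ and the python-multipart package):

from typing import Annotated, Optional

from fastapi import FastAPI, Form, Request
from pydantic import BaseModel

app = FastAPI()


class TranscriptionForm(BaseModel):
    # Hypothetical fields for illustration.
    model: str
    language: Optional[str] = None


@app.post("/v1/audio/transcriptions")
async def create_transcriptions(raw_request: Request,
                                request: Annotated[TranscriptionForm,
                                                   Form()]):
    # Fields arrive as multipart/form-data and are validated by Pydantic.
    return {"model": request.model, "language": request.language}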
@@ -661,7 +806,16 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post("/rerank", dependencies=[Depends(validate_json_request)])
+@router.post("/rerank",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 @load_aware_call
 async def do_rerank(request: RerankRequest, raw_request: Request):
@@ -679,7 +833,16 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
+@router.post("/v1/rerank",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 async def do_rerank_v1(request: RerankRequest, raw_request: Request):
     logger.warning_once(
@@ -690,7 +853,16 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
     return await do_rerank(request, raw_request)


-@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
+@router.post("/v2/rerank",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 @with_cancellation
 async def do_rerank_v2(request: RerankRequest, raw_request: Request):
     return await do_rerank(request, raw_request)
@@ -770,12 +942,29 @@ if envs.VLLM_SERVER_DEV_MODE:
         return JSONResponse(content={"is_sleeping": is_sleeping})


-@router.post("/invocations", dependencies=[Depends(validate_json_request)])
+@router.post("/invocations",
+             dependencies=[Depends(validate_json_request)],
+             responses={
+                 HTTPStatus.BAD_REQUEST.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {
+                     "model": ErrorResponse
+                 },
+                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                     "model": ErrorResponse
+                 },
+             })
 async def invocations(raw_request: Request):
     """
     For SageMaker, routes requests to other handlers based on model `task`.
     """
-    body = await raw_request.json()
+    try:
+        body = await raw_request.json()
+    except JSONDecodeError as e:
+        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                            detail=f"JSON decode error: {e}") from e
+
     task = raw_request.app.state.task

     if task not in TASK_HANDLERS:
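Starlette's Request.json() raises json.JSONDecodeError on a malformed body, which would otherwise surface as a 500. A standalone sketch of the new guard:

from http import HTTPStatus
from json import JSONDecodeError

from fastapi import FastAPI, HTTPException, Request

app = FastAPI()


@app.post("/invocations")
async def invocations(raw_request: Request):
    try:
        body = await raw_request.json()
    except JSONDecodeError as e:
        # Without the guard, the decode error would surface as a 500.
        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
                            detail=f"JSON decode error: {e}") from e
    return {"received": body}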
@@ -866,10 +1055,26 @@ def build_app(args: Namespace) -> FastAPI:
         allow_headers=args.allowed_headers,
     )

+    @app.exception_handler(HTTPException)
+    async def http_exception_handler(_: Request, exc: HTTPException):
+        err = ErrorResponse(message=exc.detail,
+                            type=HTTPStatus(exc.status_code).phrase,
+                            code=exc.status_code)
+        return JSONResponse(err.model_dump(), status_code=exc.status_code)
+
     @app.exception_handler(RequestValidationError)
-    async def validation_exception_handler(_, exc):
-        err = ErrorResponse(message=str(exc),
-                            type="BadRequestError",
+    async def validation_exception_handler(_: Request,
+                                           exc: RequestValidationError):
+        exc_str = str(exc)
+        errors_str = str(exc.errors())
+
+        if exc.errors() and errors_str and errors_str != exc_str:
+            message = f"{exc_str} {errors_str}"
+        else:
+            message = exc_str
+
+        err = ErrorResponse(message=message,
+                            type=HTTPStatus.BAD_REQUEST.phrase,
                             code=HTTPStatus.BAD_REQUEST)
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)
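With the new HTTPException handler registered next to the validation one, every error raised by the endpoints above is serialized through the same ErrorResponse envelope. A sketch with a stand-in model (not vLLM's ErrorResponse):

from http import HTTPStatus

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel


class ErrorResponse(BaseModel):
    message: str
    type: str
    code: int


app = FastAPI()


@app.exception_handler(HTTPException)
async def http_exception_handler(_: Request, exc: HTTPException):
    # e.g. a 501 becomes {"message": ..., "type": "Not Implemented", "code": 501}
    err = ErrorResponse(message=str(exc.detail),
                        type=HTTPStatus(exc.status_code).phrase,
                        code=exc.status_code)
    return JSONResponse(err.model_dump(), status_code=exc.status_code)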


@@ -197,7 +197,7 @@ class OpenAIServingChat(OpenAIServing):
         except (ValueError, TypeError, RuntimeError,
                 jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(f"{e} {e.__cause__}")

         request_id = "chatcmpl-" \
             f"{self._base_request_id(raw_request, request.request_id)}"


@@ -91,7 +91,7 @@ class OpenAIServingTokenization(OpenAIServing):
             )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(f"{e} {e.__cause__}")

         input_ids: list[int] = []
         for i, engine_prompt in enumerate(engine_prompts):