mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:35:01 +08:00
[Bugfix]: make most of test_openai_schema.py pass (#17664)
This commit is contained in:
parent
09f106a91e
commit
f25e0d1125
@ -17,8 +17,10 @@ from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
import prometheus_client
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@ -305,15 +307,18 @@ async def validate_json_request(raw_request: Request):
|
||||
content_type = raw_request.headers.get("content-type", "").lower()
|
||||
media_type = content_type.split(";", maxsplit=1)[0]
|
||||
if media_type != "application/json":
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||
detail="Unsupported Media Type: Only 'application/json' is allowed"
|
||||
)
|
||||
raise RequestValidationError(errors=[
|
||||
"Unsupported Media Type: Only 'application/json' is allowed"
|
||||
])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PrometheusResponse(Response):
|
||||
media_type = prometheus_client.CONTENT_TYPE_LATEST
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
@ -332,6 +337,10 @@ def mount_metrics(app: FastAPI):
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
# instead of the default "application/json" which is incorrect.
|
||||
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
@ -342,7 +351,7 @@ def mount_metrics(app: FastAPI):
|
||||
"/server_info",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app)
|
||||
).add().instrument(app).expose(app, response_class=PrometheusResponse)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
@ -401,11 +410,11 @@ def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health(raw_request: Request) -> JSONResponse:
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
await engine_client(raw_request).check_health()
|
||||
return JSONResponse(content={}, status_code=200)
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
@ -427,18 +436,42 @@ async def get_server_load_metrics(request: Request):
|
||||
content={'server_load': request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.api_route("/ping", methods=["GET", "POST"])
|
||||
async def ping(raw_request: Request) -> JSONResponse:
|
||||
@router.get("/ping", response_class=Response)
|
||||
@router.post("/ping", response_class=Response)
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
|
||||
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.NOT_FOUND.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.NOT_IMPLEMENTED.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
try:
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value,
|
||||
detail=str(e)) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
detail=str(e)) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -448,12 +481,31 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/detokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.NOT_FOUND.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
try:
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise RequestValidationError(errors=[str(e)]) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
detail=str(e)) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -478,7 +530,23 @@ async def show_version():
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions",
|
||||
dependencies=[Depends(validate_json_request)])
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {
|
||||
"content": {
|
||||
"text/event-stream": {}
|
||||
}
|
||||
},
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.NOT_FOUND.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
}
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
@ -500,7 +568,24 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/v1/completions",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {
|
||||
"content": {
|
||||
"text/event-stream": {}
|
||||
}
|
||||
},
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.NOT_FOUND.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
@ -509,7 +594,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Completions API")
|
||||
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
try:
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=str(e)) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
detail=str(e)) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -519,7 +612,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/v1/embeddings",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
@ -566,7 +668,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/pooling",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
@ -606,7 +717,16 @@ async def create_classify(request: ClassificationRequest,
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/score", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
@ -625,7 +745,16 @@ async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/v1/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
@ -636,12 +765,28 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v1/audio/transcriptions")
|
||||
@router.post("/v1/audio/transcriptions",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {
|
||||
"content": {
|
||||
"text/event-stream": {}
|
||||
}
|
||||
},
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(request: Annotated[TranscriptionRequest,
|
||||
Form()],
|
||||
raw_request: Request):
|
||||
async def create_transcriptions(raw_request: Request,
|
||||
request: Annotated[TranscriptionRequest,
|
||||
Form()]):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
@ -661,7 +806,16 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
@ -679,7 +833,16 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/v1/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning_once(
|
||||
@ -690,7 +853,16 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/v2/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
@ -770,12 +942,29 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
|
||||
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
|
||||
@router.post("/invocations",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
async def invocations(raw_request: Request):
|
||||
"""
|
||||
For SageMaker, routes requests to other handlers based on model `task`.
|
||||
"""
|
||||
body = await raw_request.json()
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except JSONDecodeError as e:
|
||||
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}") from e
|
||||
|
||||
task = raw_request.app.state.task
|
||||
|
||||
if task not in TASK_HANDLERS:
|
||||
@ -866,10 +1055,26 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
allow_headers=args.allowed_headers,
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(_: Request, exc: HTTPException):
|
||||
err = ErrorResponse(message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
err = ErrorResponse(message=str(exc),
|
||||
type="BadRequestError",
|
||||
async def validation_exception_handler(_: Request,
|
||||
exc: RequestValidationError):
|
||||
exc_str = str(exc)
|
||||
errors_str = str(exc.errors())
|
||||
|
||||
if exc.errors() and errors_str and errors_str != exc_str:
|
||||
message = f"{exc_str} {errors_str}"
|
||||
else:
|
||||
message = exc_str
|
||||
|
||||
err = ErrorResponse(message=message,
|
||||
type=HTTPStatus.BAD_REQUEST.phrase,
|
||||
code=HTTPStatus.BAD_REQUEST)
|
||||
return JSONResponse(err.model_dump(),
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
@ -197,7 +197,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
except (ValueError, TypeError, RuntimeError,
|
||||
jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
request_id = "chatcmpl-" \
|
||||
f"{self._base_request_id(raw_request, request.request_id)}"
|
||||
|
||||
@ -91,7 +91,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
input_ids: list[int] = []
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user