[Bugfix] Missing Content Type returns 500 Internal Server Error (#13193)

This commit is contained in:
Vaibhav Jain 2025-02-13 20:22:22 +05:30 committed by GitHub
parent 1bc3b5e71b
commit 37dfa60037
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 15 deletions

View File

@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
max_tokens=10)
assert len(response.choices) == 1
@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client()
with pytest.raises(openai.APIStatusError):
await client.chat.completions.create(
messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_headers={
"Content-Type": "application/x-www-form-urlencoded"
})

View File

@ -19,7 +19,7 @@ from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@ -252,6 +252,15 @@ async def build_async_engine_client_from_engine_args(
multiprocess.mark_process_dead(engine_process.pid)
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json":
raise HTTPException(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
detail="Unsupported Media Type: Only 'application/json' is allowed"
)
router = APIRouter()
@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response:
return await health(raw_request)
@router.post("/tokenize")
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never(generator)
@router.post("/detokenize")
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@ -379,7 +388,8 @@ async def show_version():
return JSONResponse(content=ver)
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions",
dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling")
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator)
@router.post("/score")
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/rerank")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/rerank")
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
@router.post("/v2/rerank")
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
@ -582,7 +592,7 @@ if envs.VLLM_SERVER_DEV_MODE:
return Response(status_code=200)
@router.post("/invocations")
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
@ -632,7 +642,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
@router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
handler = models(raw_request)
@ -643,7 +654,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
@router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
handler = models(raw_request)