diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 0d44a7611aed4..a970981b75626 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer): max_tokens=10) assert len(response.choices) == 1 + + +@pytest.mark.asyncio +async def test_request_wrong_content_type(server: RemoteOpenAIServer): + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client() + + with pytest.raises(openai.APIStatusError): + await client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_headers={ + "Content-Type": "application/x-www-form-urlencoded" + }) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 588a7781c11e6..b50a72f3a6c1a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -19,7 +19,7 @@ from http import HTTPStatus from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -252,6 +252,15 @@ async def build_async_engine_client_from_engine_args( multiprocess.mark_process_dead(engine_process.pid) +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + detail="Unsupported Media Type: Only 'application/json' is allowed" + ) + + router = APIRouter() @@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize") +@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize") +@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -379,7 +388,8 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): @@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions") +@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) @@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings") +@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) @@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling") +@router.post("/pooling", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) @@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/score") +@router.post("/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) @@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score") +@router.post("/v1/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( @@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -@router.post("/rerank") +@router.post("/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) @@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank") +@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank") +@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -582,7 +592,7 @@ if envs.VLLM_SERVER_DEV_MODE: return Response(status_code=200) -@router.post("/invocations") +@router.post("/invocations", dependencies=[Depends(validate_json_request)]) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. @@ -632,7 +642,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: "Lora dynamic loading & unloading is enabled in the API server. " "This should ONLY be used for local development!") - @router.post("/v1/load_lora_adapter") + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): handler = models(raw_request) @@ -643,7 +654,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: return Response(status_code=200, content=response) - @router.post("/v1/unload_lora_adapter") + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): handler = models(raw_request)