[Frontend] Update OpenAI error response to upstream format (#22099)

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Moritz Sanft 2025-08-07 08:06:00 +02:00 committed by GitHub
parent cbc8457b26
commit 370661856b
10 changed files with 73 additions and 67 deletions
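
In short: error bodies move from vLLM's flat, top-level layout to the upstream OpenAI layout, with every field nested under an "error" key. A sketch of the shape change (field names come from the diff below; the concrete values are illustrative):

    # Before this commit: flat, vLLM-specific error body.
    old_body = {
        "object": "error",
        "message": "The model does not support Completions API",
        "type": "BadRequestError",
        "param": None,
        "code": 400,
    }

    # After this commit: nested under "error", matching the upstream
    # OpenAI format; the top-level "object" field is dropped.
    new_body = {
        "error": {
            "message": "The model does not support Completions API",
            "type": "BadRequestError",
            "param": None,
            "code": 400,
        }
    }

Clients accordingly read body["error"]["message"] instead of body["message"], as the updated tests below show.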

View File

@@ -121,8 +121,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
     error = classification_response.json()
     assert classification_response.status_code == 400
-    assert error["object"] == "error"
-    assert "truncate_prompt_tokens" in error["message"]
+    assert "truncate_prompt_tokens" in error["error"]["message"]
 
 
 @pytest.mark.parametrize("model_name", [MODEL_NAME])

@@ -137,7 +136,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
     error = classification_response.json()
     assert classification_response.status_code == 400
-    assert error["object"] == "error"
+    assert "error" in error
 
 
 @pytest.mark.parametrize("model_name", [MODEL_NAME])

View File

@@ -160,8 +160,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
     mock_engine.generate.assert_not_called()
 
     assert isinstance(response, ErrorResponse)
-    assert response.code == HTTPStatus.NOT_FOUND.value
-    assert non_existent_model in response.message
+    assert response.error.code == HTTPStatus.NOT_FOUND.value
+    assert non_existent_model in response.error.message
 
 
 @pytest.mark.asyncio

@@ -190,8 +190,8 @@ async def test_serving_completion_resolver_add_lora_fails(
     # Assert the correct error response
     assert isinstance(response, ErrorResponse)
-    assert response.code == HTTPStatus.BAD_REQUEST.value
-    assert invalid_model in response.message
+    assert response.error.code == HTTPStatus.BAD_REQUEST.value
+    assert invalid_model in response.error.message
 
 
 @pytest.mark.asyncio

View File

@@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields():
     request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
     response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio

@@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate():
                                      lora_path="/path/to/adapter1")
     response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST
     assert len(serving_models.lora_requests) == 1

@@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields():
     request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
     response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio

@@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found():
     request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
     response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "NotFoundError"
-    assert response.code == HTTPStatus.NOT_FOUND
+    assert response.error.type == "NotFoundError"
+    assert response.error.code == HTTPStatus.NOT_FOUND

View File

@@ -116,8 +116,10 @@ async def test_non_asr_model(winning_call):
                                                      file=winning_call,
                                                      language="en",
                                                      temperature=0.0)
-    assert res.code == 400 and not res.text
-    assert res.message == "The model does not support Transcriptions API"
+    err = res.error
+    assert err["code"] == 400 and not res.text
+    assert err[
+        "message"] == "The model does not support Transcriptions API"
 
 
 @pytest.mark.asyncio

@@ -133,12 +135,15 @@ async def test_completion_endpoints():
         "role": "system",
         "content": "You are a helpful assistant."
     }])
-    assert res.code == 400
-    assert res.message == "The model does not support Chat Completions API"
+    err = res.error
+    assert err["code"] == 400
+    assert err[
+        "message"] == "The model does not support Chat Completions API"
 
     res = await client.completions.create(model=model_name, prompt="Hello")
-    assert res.code == 400
-    assert res.message == "The model does not support Completions API"
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Completions API"
 
 
 @pytest.mark.asyncio

View File

@@ -73,8 +73,9 @@ async def test_non_asr_model(foscolo):
     res = await client.audio.translations.create(model=model_name,
                                                  file=foscolo,
                                                  temperature=0.0)
-    assert res.code == 400 and not res.text
-    assert res.message == "The model does not support Translations API"
+    err = res.error
+    assert err["code"] == 400 and not res.text
+    assert err["message"] == "The model does not support Translations API"
 
 
 @pytest.mark.asyncio

View File

@@ -62,7 +62,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
                                               EmbeddingRequest,
-                                              EmbeddingResponse, ErrorResponse,
+                                              EmbeddingResponse, ErrorInfo,
+                                              ErrorResponse,
                                               LoadLoRAAdapterRequest,
                                               PoolingRequest, PoolingResponse,
                                               RerankRequest, RerankResponse,

@@ -506,7 +507,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, TokenizeResponse):
         return JSONResponse(content=generator.model_dump())

@@ -540,7 +541,7 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, DetokenizeResponse):
         return JSONResponse(content=generator.model_dump())

@@ -556,7 +557,7 @@ def maybe_register_tokenizer_info_endpoint(args):
         """Get comprehensive tokenizer information."""
         result = await tokenization(raw_request).get_tokenizer_info()
         return JSONResponse(content=result.model_dump(),
-                            status_code=result.code if isinstance(
+                            status_code=result.error.code if isinstance(
                                 result, ErrorResponse) else 200)

@@ -603,7 +604,7 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ResponsesResponse):
         return JSONResponse(content=generator.model_dump())
     return StreamingResponse(content=generator, media_type="text/event-stream")

@@ -620,7 +621,7 @@ async def retrieve_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)
     return JSONResponse(content=response.model_dump())

@@ -635,7 +636,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)
     return JSONResponse(content=response.model_dump())

@@ -670,7 +671,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ChatCompletionResponse):
         return JSONResponse(content=generator.model_dump())

@@ -715,7 +716,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, CompletionResponse):
         return JSONResponse(content=generator.model_dump())

@@ -744,7 +745,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, EmbeddingResponse):
         return JSONResponse(content=generator.model_dump())

@@ -772,7 +773,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
     generator = await handler.create_pooling(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, PoolingResponse):
         return JSONResponse(content=generator.model_dump())

@@ -792,7 +793,7 @@ async def create_classify(request: ClassificationRequest,
     generator = await handler.create_classify(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ClassificationResponse):
         return JSONResponse(content=generator.model_dump())

@@ -821,7 +822,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
     generator = await handler.create_score(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ScoreResponse):
         return JSONResponse(content=generator.model_dump())

@@ -881,7 +882,7 @@ async def create_transcriptions(raw_request: Request,
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, TranscriptionResponse):
         return JSONResponse(content=generator.model_dump())

@@ -922,7 +923,7 @@ async def create_translations(request: Annotated[TranslationRequest,
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, TranslationResponse):
         return JSONResponse(content=generator.model_dump())

@@ -950,7 +951,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
     generator = await handler.do_rerank(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, RerankResponse):
         return JSONResponse(content=generator.model_dump())

@@ -1175,7 +1176,7 @@ async def invocations(raw_request: Request):
         msg = ("Cannot find suitable handler for request. "
                f"Expected one of: {type_names}")
         res = base(raw_request).create_error_response(message=msg)
-        return JSONResponse(content=res.model_dump(), status_code=res.code)
+        return JSONResponse(content=res.model_dump(), status_code=res.error.code)
 
 
 if envs.VLLM_TORCH_PROFILER_DIR:

@@ -1211,7 +1212,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         response = await handler.load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+                                status_code=response.error.code)
 
         return Response(status_code=200, content=response)

@@ -1223,7 +1224,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         response = await handler.unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+                                status_code=response.error.code)
 
         return Response(status_code=200, content=response)

@@ -1502,9 +1503,10 @@ def build_app(args: Namespace) -> FastAPI:
     @app.exception_handler(HTTPException)
     async def http_exception_handler(_: Request, exc: HTTPException):
-        err = ErrorResponse(message=exc.detail,
-                            type=HTTPStatus(exc.status_code).phrase,
-                            code=exc.status_code)
+        err = ErrorResponse(
+            error=ErrorInfo(message=exc.detail,
+                            type=HTTPStatus(exc.status_code).phrase,
+                            code=exc.status_code))
         return JSONResponse(err.model_dump(), status_code=exc.status_code)
 
     @app.exception_handler(RequestValidationError)

@@ -1518,9 +1520,9 @@ def build_app(args: Namespace) -> FastAPI:
         else:
             message = exc_str
-        err = ErrorResponse(message=message,
-                            type=HTTPStatus.BAD_REQUEST.phrase,
-                            code=HTTPStatus.BAD_REQUEST)
+        err = ErrorResponse(error=ErrorInfo(message=message,
+                                            type=HTTPStatus.BAD_REQUEST.phrase,
+                                            code=HTTPStatus.BAD_REQUEST))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)
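
For context, a self-contained sketch of the exception-handler pattern in the hunks above, using minimal stand-ins for the real protocol classes (FastAPI and pydantic v2 assumed; the real models derive from OpenAIBaseModel):

    from http import HTTPStatus
    from typing import Optional

    from fastapi import FastAPI, HTTPException, Request
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel


    # Minimal stand-ins for the protocol models defined in this commit.
    class ErrorInfo(BaseModel):
        message: str
        type: str
        param: Optional[str] = None
        code: int


    class ErrorResponse(BaseModel):
        error: ErrorInfo


    app = FastAPI()


    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        # Build the nested, upstream-style error body and mirror the
        # HTTP status code into it.
        err = ErrorResponse(error=ErrorInfo(
            message=exc.detail,
            type=HTTPStatus(exc.status_code).phrase,
            code=exc.status_code))
        return JSONResponse(err.model_dump(), status_code=exc.status_code)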

View File

@@ -78,14 +78,17 @@ class OpenAIBaseModel(BaseModel):
         return result
 
 
-class ErrorResponse(OpenAIBaseModel):
-    object: str = "error"
+class ErrorInfo(OpenAIBaseModel):
     message: str
     type: str
     param: Optional[str] = None
     code: int
+
+
+class ErrorResponse(OpenAIBaseModel):
+    error: ErrorInfo
 
 
 class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
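
As a quick check of the new models, a sketch of how they serialize (assuming an environment with a vLLM build that includes this change):

    from http import HTTPStatus

    from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

    err = ErrorResponse(error=ErrorInfo(
        message="model not found",       # illustrative values
        type="NotFoundError",
        code=HTTPStatus.NOT_FOUND.value))

    # model_dump() now yields the nested upstream shape:
    # {'error': {'message': 'model not found', 'type': 'NotFoundError',
    #            'param': None, 'code': 404}}
    print(err.model_dump())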

View File

@@ -302,7 +302,7 @@ async def run_request(serving_engine_func: Callable,
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
             response=BatchResponseData(
-                status_code=response.code,
+                status_code=response.error.code,
                 request_id=f"vllm-batch-{random_uuid()}"),
             error=response,
         )

View File

@@ -47,10 +47,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               EmbeddingChatRequest,
                                               EmbeddingCompletionRequest,
                                               EmbeddingRequest,
-                                              EmbeddingResponse, ErrorResponse,
-                                              PoolingResponse, RerankRequest,
-                                              ResponsesRequest, ScoreRequest,
-                                              ScoreResponse,
+                                              EmbeddingResponse, ErrorInfo,
+                                              ErrorResponse, PoolingResponse,
+                                              RerankRequest, ResponsesRequest,
+                                              ScoreRequest, ScoreResponse,
                                               TokenizeChatRequest,
                                               TokenizeCompletionRequest,
                                               TokenizeResponse,

@@ -412,21 +412,18 @@ class OpenAIServing:
             message: str,
             err_type: str = "BadRequestError",
             status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
-        return ErrorResponse(message=message,
-                             type=err_type,
-                             code=status_code.value)
+        return ErrorResponse(error=ErrorInfo(
+            message=message, type=err_type, code=status_code.value))
 
     def create_streaming_error_response(
             self,
             message: str,
             err_type: str = "BadRequestError",
             status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
-        json_str = json.dumps({
-            "error":
-            self.create_error_response(message=message,
-                                       err_type=err_type,
-                                       status_code=status_code).model_dump()
-        })
+        json_str = json.dumps(
+            self.create_error_response(message=message,
+                                       err_type=err_type,
+                                       status_code=status_code).model_dump())
         return json_str
 
     async def _check_model(

@@ -445,7 +442,7 @@ class OpenAIServing:
         if isinstance(load_result, LoRARequest):
             return None
         if isinstance(load_result, ErrorResponse) and \
-            load_result.code == HTTPStatus.BAD_REQUEST.value:
+            load_result.error.code == HTTPStatus.BAD_REQUEST.value:
             error_response = load_result
 
         return error_response or self.create_error_response(
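
One consequence worth noting in create_streaming_error_response above: the call site no longer wraps the dump in an extra {"error": ...} layer by hand, since ErrorResponse.model_dump() already produces the nesting. A before/after sketch with hand-written dicts (values illustrative):

    import json

    info = {"message": "bad request", "type": "BadRequestError",
            "param": None, "code": 400}

    # Before: the flat model_dump() was wrapped in "error" at the call site.
    old_json = json.dumps({"error": {"object": "error", **info}})

    # After: the nested model_dump() is serialized as-is.
    new_json = json.dumps({"error": info})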

View File

@@ -9,7 +9,7 @@ from typing import Optional, Union
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
+from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
                                               LoadLoRAAdapterRequest,
                                               ModelCard, ModelList,
                                               ModelPermission,

@@ -82,7 +82,7 @@ class OpenAIServingModels:
             load_result = await self.load_lora_adapter(
                 request=load_request, base_model_name=lora.base_model_name)
             if isinstance(load_result, ErrorResponse):
-                raise ValueError(load_result.message)
+                raise ValueError(load_result.error.message)
 
     def is_base_model(self, model_name) -> bool:
         return any(model.name == model_name for model in self.base_model_paths)

@@ -284,6 +284,5 @@ def create_error_response(
         message: str,
         err_type: str = "BadRequestError",
         status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
-    return ErrorResponse(message=message,
-                         type=err_type,
-                         code=status_code.value)
+    return ErrorResponse(error=ErrorInfo(
+        message=message, type=err_type, code=status_code.value))