[OpenAI] Add parameter metadata to validation errors (#30134)
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
parent 23daef548d
commit 769f27e701
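With this change, request-validation failures carry the name of the offending parameter through to the API error body. For illustration only (message, type, and code depend on which validation path produced the error; values below are hypothetical), a rejected request should now yield a 400 body along these lines:

# Illustrative error body once `param` is populated, e.g. for a request
# that sets `stream_options` without `stream=True`:
{
    "error": {
        "message": "...",
        "type": "BadRequestError",
        "code": 400,
        "param": "stream_options",
    }
}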
@@ -909,6 +909,16 @@ def build_app(args: Namespace) -> FastAPI:
 
     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_: Request, exc: RequestValidationError):
+        from vllm.entrypoints.openai.protocol import VLLMValidationError
+
+        param = None
+        for error in exc.errors():
+            if "ctx" in error and "error" in error["ctx"]:
+                ctx_error = error["ctx"]["error"]
+                if isinstance(ctx_error, VLLMValidationError):
+                    param = ctx_error.parameter
+                    break
+
         exc_str = str(exc)
         errors_str = str(exc.errors())
 
@@ -922,6 +932,7 @@ def build_app(args: Namespace) -> FastAPI:
                 message=message,
                 type=HTTPStatus.BAD_REQUEST.phrase,
                 code=HTTPStatus.BAD_REQUEST,
+                param=param,
             )
         )
         return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
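For background on the loop in the handler above: when a pydantic model validator raises an exception, pydantic surfaces the original exception object under the "ctx" key of each entry of exc.errors(), which is what the isinstance check relies on. Roughly, one such entry looks like this (an illustrative sketch; keys other than "ctx" abbreviated and not guaranteed):

# Illustrative shape of one element of exc.errors() after a model
# validator raised VLLMValidationError; only "ctx"["error"] matters here.
{
    "type": "value_error",
    "loc": ("body",),
    "msg": "Value error, Stream options can only be defined when `stream=True`.",
    "ctx": {"error": VLLMValidationError("...", parameter="stream_options")},
}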
@@ -131,6 +131,36 @@ class ErrorResponse(OpenAIBaseModel):
     error: ErrorInfo
 
 
+class VLLMValidationError(ValueError):
+    """vLLM-specific validation error for request validation failures.
+
+    Args:
+        message: The error message describing the validation failure.
+        parameter: Optional parameter name that failed validation.
+        value: Optional value that was rejected during validation.
+    """
+
+    def __init__(
+        self,
+        message: str,
+        *,
+        parameter: str | None = None,
+        value: Any = None,
+    ) -> None:
+        super().__init__(message)
+        self.parameter = parameter
+        self.value = value
+
+    def __str__(self):
+        base = super().__str__()
+        extras = []
+        if self.parameter is not None:
+            extras.append(f"parameter={self.parameter}")
+        if self.value is not None:
+            extras.append(f"value={self.value}")
+        return f"{base} ({', '.join(extras)})" if extras else base
+
+
 class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
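A minimal usage sketch of the new exception, just to show what the metadata and the __str__ rendering look like (standalone; nothing beyond the class above is assumed):

# Minimal sketch of VLLMValidationError metadata and string rendering.
err = VLLMValidationError(
    "`top_logprobs` must be a positive value or -1.",
    parameter="top_logprobs",
    value=-5,
)
assert isinstance(err, ValueError)
assert err.parameter == "top_logprobs" and err.value == -5
# str(err) ->
# "`top_logprobs` must be a positive value or -1. (parameter=top_logprobs, value=-5)"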
@@ -466,7 +496,9 @@ class ResponsesRequest(OpenAIBaseModel):
     @model_validator(mode="before")
     def validate_prompt(cls, data):
         if data.get("prompt") is not None:
-            raise ValueError("prompt template is not supported")
+            raise VLLMValidationError(
+                "prompt template is not supported", parameter="prompt"
+            )
         return data
 
     @model_validator(mode="before")
@@ -850,7 +882,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     @classmethod
     def validate_stream_options(cls, data):
         if data.get("stream_options") and not data.get("stream"):
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+            raise VLLMValidationError(
+                "Stream options can only be defined when `stream=True`.",
+                parameter="stream_options",
+            )
 
         return data
 
@@ -859,19 +894,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
     def check_logprobs(cls, data):
         if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
             if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
-                raise ValueError(
-                    "`prompt_logprobs` are not available when `stream=True`."
+                raise VLLMValidationError(
+                    "`prompt_logprobs` are not available when `stream=True`.",
+                    parameter="prompt_logprobs",
                 )
 
             if prompt_logprobs < 0 and prompt_logprobs != -1:
-                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
+                raise VLLMValidationError(
+                    "`prompt_logprobs` must be a positive value or -1.",
+                    parameter="prompt_logprobs",
+                    value=prompt_logprobs,
+                )
         if (top_logprobs := data.get("top_logprobs")) is not None:
             if top_logprobs < 0 and top_logprobs != -1:
-                raise ValueError("`top_logprobs` must be a positive value or -1.")
+                raise VLLMValidationError(
+                    "`top_logprobs` must be a positive value or -1.",
+                    parameter="top_logprobs",
+                    value=top_logprobs,
+                )
 
             if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
-                raise ValueError(
-                    "when using `top_logprobs`, `logprobs` must be set to true."
+                raise VLLMValidationError(
+                    "when using `top_logprobs`, `logprobs` must be set to true.",
+                    parameter="top_logprobs",
                 )
 
         return data
@@ -1285,9 +1330,10 @@ class CompletionRequest(OpenAIBaseModel):
             for k in ("json", "regex", "choice")
         )
         if count > 1:
-            raise ValueError(
+            raise VLLMValidationError(
                 "You can only use one kind of constraints for structured "
-                "outputs ('json', 'regex' or 'choice')."
+                "outputs ('json', 'regex' or 'choice').",
+                parameter="structured_outputs",
             )
         return data
 
@@ -1296,14 +1342,23 @@ class CompletionRequest(OpenAIBaseModel):
     def check_logprobs(cls, data):
         if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
             if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
-                raise ValueError(
-                    "`prompt_logprobs` are not available when `stream=True`."
+                raise VLLMValidationError(
+                    "`prompt_logprobs` are not available when `stream=True`.",
+                    parameter="prompt_logprobs",
                 )
 
             if prompt_logprobs < 0 and prompt_logprobs != -1:
-                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
+                raise VLLMValidationError(
+                    "`prompt_logprobs` must be a positive value or -1.",
+                    parameter="prompt_logprobs",
+                    value=prompt_logprobs,
+                )
         if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
-            raise ValueError("`logprobs` must be a positive value.")
+            raise VLLMValidationError(
+                "`logprobs` must be a positive value.",
+                parameter="logprobs",
+                value=logprobs,
+            )
 
         return data
 
@@ -1311,7 +1366,10 @@ class CompletionRequest(OpenAIBaseModel):
     @classmethod
     def validate_stream_options(cls, data):
         if data.get("stream_options") and not data.get("stream"):
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+            raise VLLMValidationError(
+                "Stream options can only be defined when `stream=True`.",
+                parameter="stream_options",
+            )
 
         return data
 
@@ -2138,7 +2196,15 @@ class TranscriptionRequest(OpenAIBaseModel):
         stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
         stream = data.get("stream", False)
         if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+            # Find which specific stream option was set
+            invalid_param = next(
+                (so for so in stream_opts if data.get(so, False)),
+                "stream_include_usage",
+            )
+            raise VLLMValidationError(
+                "Stream options can only be defined when `stream=True`.",
+                parameter=invalid_param,
+            )
 
         return data
 
@@ -2351,7 +2417,15 @@ class TranslationRequest(OpenAIBaseModel):
         stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
         stream = data.get("stream", False)
         if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+            # Find which specific stream option was set
+            invalid_param = next(
+                (so for so in stream_opts if data.get(so, False)),
+                "stream_include_usage",
+            )
+            raise VLLMValidationError(
+                "Stream options can only be defined when `stream=True`.",
+                parameter=invalid_param,
+            )
 
         return data
 
@@ -417,8 +417,7 @@ class OpenAIServingChat(OpenAIServing):
 
                 generators.append(generator)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         assert len(generators) == 1
         (result_generator,) = generators
@@ -448,8 +447,7 @@ class OpenAIServingChat(OpenAIServing):
         except GenerationError as e:
             return self._convert_generation_error_to_response(e)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
@@ -682,7 +680,7 @@ class OpenAIServingChat(OpenAIServing):
             tool_parsers = [None] * num_choices
         except Exception as e:
             logger.exception("Error in tool parser creation.")
-            data = self.create_streaming_error_response(str(e))
+            data = self.create_streaming_error_response(e)
             yield f"data: {data}\n\n"
             yield "data: [DONE]\n\n"
             return
@@ -1328,9 +1326,8 @@ class OpenAIServingChat(OpenAIServing):
         except GenerationError as e:
             yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
             logger.exception("Error in chat completion stream generator.")
-            data = self.create_streaming_error_response(str(e))
+            data = self.create_streaming_error_response(e)
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
@@ -1354,8 +1351,7 @@ class OpenAIServingChat(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         assert final_res is not None
 
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     UsageInfo,
+    VLLMValidationError,
 )
 from vllm.entrypoints.openai.serving_engine import (
     GenerationError,
@@ -247,8 +248,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
                 generators.append(generator)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         result_generator = merge_async_iterators(*generators)
 
@@ -308,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing):
         except GenerationError as e:
             return self._convert_generation_error_to_response(e)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
@@ -510,9 +509,8 @@ class OpenAIServingCompletion(OpenAIServing):
         except GenerationError as e:
             yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
             logger.exception("Error in completion stream generator.")
-            data = self.create_streaming_error_response(str(e))
+            data = self.create_streaming_error_response(e)
             yield f"data: {data}\n\n"
         yield "data: [DONE]\n\n"
 
@@ -660,8 +658,11 @@ class OpenAIServingCompletion(OpenAIServing):
                 token = f"token_id:{token_id}"
             else:
                 if tokenizer is None:
-                    raise ValueError(
-                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
+                    raise VLLMValidationError(
+                        "Unable to get tokenizer because "
+                        "`skip_tokenizer_init=True`",
+                        parameter="skip_tokenizer_init",
+                        value=True,
                     )
 
                 token = tokenizer.decode(token_id)
@@ -720,6 +721,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request: CompletionRequest,
         max_input_length: int | None = None,
     ) -> RenderConfig:
+        # Validate max_tokens before using it
+        if request.max_tokens is not None and request.max_tokens > self.max_model_len:
+            raise VLLMValidationError(
+                f"'max_tokens' ({request.max_tokens}) cannot be greater than "
+                f"the model's maximum context length ({self.max_model_len}).",
+                parameter="max_tokens",
+                value=request.max_tokens,
+            )
+
         max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
         return RenderConfig(
             max_length=max_input_tokens_len,
 
@@ -57,6 +57,7 @@ from vllm.entrypoints.openai.protocol import (
     TranscriptionRequest,
     TranscriptionResponse,
     TranslationRequest,
+    VLLMValidationError,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.pooling.classify.protocol import (
@@ -322,8 +323,10 @@ class OpenAIServing:
         input_processor = self.input_processor
         tokenizer = input_processor.tokenizer
         if tokenizer is None:
-            raise ValueError(
-                "You cannot use beam search when `skip_tokenizer_init=True`"
+            raise VLLMValidationError(
+                "You cannot use beam search when `skip_tokenizer_init=True`",
+                parameter="skip_tokenizer_init",
+                value=True,
             )
 
         eos_token_id: int = tokenizer.eos_token_id  # type: ignore
@@ -706,8 +709,7 @@ class OpenAIServing:
                 return None
 
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def _collect_batch(
         self,
@@ -738,14 +740,43 @@ class OpenAIServing:
                 return None
 
         except Exception as e:
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     def create_error_response(
         self,
-        message: str,
+        message: str | Exception,
         err_type: str = "BadRequestError",
         status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+        param: str | None = None,
     ) -> ErrorResponse:
+        exc: Exception | None = None
+
+        if isinstance(message, Exception):
+            exc = message
+
+            from vllm.entrypoints.openai.protocol import VLLMValidationError
+
+            if isinstance(exc, VLLMValidationError):
+                err_type = "BadRequestError"
+                status_code = HTTPStatus.BAD_REQUEST
+                param = exc.parameter
+            elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
+                # Common validation errors from user input
+                err_type = "BadRequestError"
+                status_code = HTTPStatus.BAD_REQUEST
+                param = None
+            elif exc.__class__.__name__ == "TemplateError":
+                # jinja2.TemplateError (avoid importing jinja2)
+                err_type = "BadRequestError"
+                status_code = HTTPStatus.BAD_REQUEST
+                param = None
+            else:
+                err_type = "InternalServerError"
+                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+                param = None
+
+            message = str(exc)
+
         if self.log_error_stack:
             exc_type, _, _ = sys.exc_info()
             if exc_type is not None:
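To make the intent of the classification above concrete, a rough sketch of the resulting error fields (`serving` stands for any OpenAIServing instance; the comments state what the mapping implies, not captured output):

# Rough sketch of what create_error_response now produces (illustrative):
resp = serving.create_error_response(
    VLLMValidationError(
        "`logprobs` must be a positive value.", parameter="logprobs", value=-1
    )
)
# -> error.type == "BadRequestError", error.code == 400, error.param == "logprobs"

resp = serving.create_error_response(ValueError("bad request field"))
# -> error.type == "BadRequestError", error.code == 400, error.param is None

resp = serving.create_error_response(KeyError("unexpected"))
# -> error.type == "InternalServerError", error.code == 500, error.param is None

resp = serving.create_error_response("plain string message")
# -> passes through with the defaults: error.type == "BadRequestError", error.param is None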
@@ -753,18 +784,27 @@ class OpenAIServing:
             else:
                 traceback.print_stack()
         return ErrorResponse(
-            error=ErrorInfo(message=message, type=err_type, code=status_code.value)
+            error=ErrorInfo(
+                message=message,
+                type=err_type,
+                code=status_code.value,
+                param=param,
+            )
         )
 
     def create_streaming_error_response(
         self,
-        message: str,
+        message: str | Exception,
         err_type: str = "BadRequestError",
         status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+        param: str | None = None,
     ) -> str:
         json_str = json.dumps(
             self.create_error_response(
-                message=message, err_type=err_type, status_code=status_code
+                message=message,
+                err_type=err_type,
+                status_code=status_code,
+                param=param,
             ).model_dump()
         )
         return json_str
@@ -825,6 +865,7 @@ class OpenAIServing:
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND,
+            param="model",
         )
 
     def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
@@ -991,11 +1032,13 @@ class OpenAIServing:
                 ClassificationChatRequest: "classification",
             }
             operation = operations.get(type(request), "embedding generation")
-            raise ValueError(
+            raise VLLMValidationError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
                 f"{token_num} tokens in the input for {operation}. "
-                f"Please reduce the length of the input."
+                f"Please reduce the length of the input.",
+                parameter="input_tokens",
+                value=token_num,
             )
         return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
 
@@ -1017,20 +1060,24 @@ class OpenAIServing:
         # Note: input length can be up to model context length - 1 for
         # completion-like requests.
         if token_num >= self.max_model_len:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, your request has "
                 f"{token_num} input tokens. Please reduce the length of "
-                "the input messages."
+                "the input messages.",
+                parameter="input_tokens",
+                value=token_num,
            )
 
        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
-            raise ValueError(
+            raise VLLMValidationError(
                 "'max_tokens' or 'max_completion_tokens' is too large: "
                 f"{max_tokens}. This model's maximum context length is "
                 f"{self.max_model_len} tokens and your request has "
                 f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
-                f" - {token_num})."
+                f" - {token_num}).",
+                parameter="max_tokens",
+                value=max_tokens,
             )
 
         return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
 
@@ -94,6 +94,7 @@ from vllm.entrypoints.openai.protocol import (
     ResponsesResponse,
     ResponseUsage,
     StreamingResponsesResponse,
+    VLLMValidationError,
 )
 from vllm.entrypoints.openai.serving_engine import (
     GenerationError,
@@ -271,6 +272,7 @@ class OpenAIServingResponses(OpenAIServing):
                 err_type="invalid_request_error",
                 message=error_message,
                 status_code=HTTPStatus.BAD_REQUEST,
+                param="input",
             )
         return None
 
@@ -282,6 +284,7 @@ class OpenAIServingResponses(OpenAIServing):
                 err_type="invalid_request_error",
                 message="logprobs are not supported with gpt-oss models",
                 status_code=HTTPStatus.BAD_REQUEST,
+                param="logprobs",
             )
         if request.store and not self.enable_store and request.background:
             return self.create_error_response(
@@ -294,6 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
                     "the vLLM server."
                 ),
                 status_code=HTTPStatus.BAD_REQUEST,
+                param="background",
             )
         if request.previous_input_messages and request.previous_response_id:
             return self.create_error_response(
@@ -301,6 +305,7 @@ class OpenAIServingResponses(OpenAIServing):
                 message="Only one of `previous_input_messages` and "
                 "`previous_response_id` can be set.",
                 status_code=HTTPStatus.BAD_REQUEST,
+                param="previous_response_id",
             )
         return None
 
@@ -457,8 +462,7 @@ class OpenAIServingResponses(OpenAIServing):
                 )
                 generators.append(generator)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         assert len(generators) == 1
         (result_generator,) = generators
@@ -546,7 +550,7 @@ class OpenAIServingResponses(OpenAIServing):
         except GenerationError as e:
             return self._convert_generation_error_to_response(e)
         except Exception as e:
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def _make_request(
         self,
@@ -630,8 +634,7 @@ class OpenAIServingResponses(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         # NOTE: Implementation of stauts is still WIP, but for now
         # we guarantee that if the status is not "completed", it is accurate.
@@ -1074,7 +1077,7 @@ class OpenAIServingResponses(OpenAIServing):
             response = self._convert_generation_error_to_response(e)
         except Exception as e:
             logger.exception("Background request failed for %s", request.request_id)
-            response = self.create_error_response(str(e))
+            response = self.create_error_response(e)
         finally:
             new_event_signal.set()
 
@@ -1099,7 +1102,7 @@ class OpenAIServingResponses(OpenAIServing):
             response = self._convert_generation_error_to_response(e)
         except Exception as e:
             logger.exception("Background request failed for %s", request.request_id)
-            response = self.create_error_response(str(e))
+            response = self.create_error_response(e)
 
         if isinstance(response, ErrorResponse):
             # If the request has failed, update the status to "failed".
@@ -1116,7 +1119,11 @@ class OpenAIServingResponses(OpenAIServing):
         starting_after: int | None = None,
     ) -> AsyncGenerator[StreamingResponsesResponse, None]:
         if response_id not in self.event_store:
-            raise ValueError(f"Unknown response_id: {response_id}")
+            raise VLLMValidationError(
+                f"Unknown response_id: {response_id}",
+                parameter="response_id",
+                value=response_id,
+            )
 
         event_deque, new_event_signal = self.event_store[response_id]
         start_index = 0 if starting_after is None else starting_after + 1
@@ -1172,6 +1179,7 @@ class OpenAIServingResponses(OpenAIServing):
             return self.create_error_response(
                 err_type="invalid_request_error",
                 message="Cannot cancel a synchronous response.",
+                param="response_id",
             )
 
         # Update the status to "cancelled".
@@ -1191,6 +1199,7 @@ class OpenAIServingResponses(OpenAIServing):
                 err_type="invalid_request_error",
                 message=f"Response with id '{response_id}' not found.",
                 status_code=HTTPStatus.NOT_FOUND,
+                param="response_id",
             )
 
     def _make_store_not_supported_error(self) -> ErrorResponse:
@@ -1203,6 +1212,7 @@ class OpenAIServingResponses(OpenAIServing):
                 "starting the vLLM server."
             ),
             status_code=HTTPStatus.BAD_REQUEST,
+            param="store",
         )
 
     async def _process_simple_streaming_events(
 
@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
     TranslationSegment,
     TranslationStreamResponse,
     UsageInfo,
+    VLLMValidationError,
 )
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -259,7 +260,11 @@ class OpenAISpeechToText(OpenAIServing):
             )
 
         if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
-            raise ValueError("Maximum file size exceeded.")
+            raise VLLMValidationError(
+                "Maximum file size exceeded",
+                parameter="audio_filesize_mb",
+                value=len(audio_data) / 1024**2,
+            )
 
         with io.BytesIO(audio_data) as bytes_:
             # NOTE resample to model SR here for efficiency. This is also a
@@ -287,12 +292,18 @@ class OpenAISpeechToText(OpenAIServing):
             )
             if request.response_format == "verbose_json":
                 if not isinstance(prompt, dict):
-                    raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
+                    raise VLLMValidationError(
+                        "Expected prompt to be a dict",
+                        parameter="prompt",
+                        value=type(prompt).__name__,
+                    )
                 prompt_dict = cast(dict, prompt)
                 decoder_prompt = prompt.get("decoder_prompt")
                 if not isinstance(decoder_prompt, str):
-                    raise ValueError(
-                        f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
+                    raise VLLMValidationError(
+                        "Expected decoder_prompt to be str",
+                        parameter="decoder_prompt",
+                        value=type(decoder_prompt).__name__,
                     )
                 prompt_dict["decoder_prompt"] = decoder_prompt.replace(
                     "<|notimestamps|>", "<|0.00|>"
@@ -412,7 +423,7 @@ class OpenAISpeechToText(OpenAIServing):
 
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
         try:
@@ -448,8 +459,7 @@ class OpenAISpeechToText(OpenAIServing):
                 for i, prompt in enumerate(prompts)
             ]
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         if request.stream:
             return stream_generator_method(
@@ -523,8 +533,7 @@ class OpenAISpeechToText(OpenAIServing):
         except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def _speech_to_text_stream_generator(
         self,
@@ -634,9 +643,8 @@ class OpenAISpeechToText(OpenAIServing):
                 )
 
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
             logger.exception("Error in %s stream generator.", self.task_type)
-            data = self.create_streaming_error_response(str(e))
+            data = self.create_streaming_error_response(e)
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
 
@@ -12,6 +12,7 @@ import torch
 from pydantic import Field
 
 from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import VLLMValidationError
 from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.tokenizers import TokenizerLike
@@ -162,8 +163,9 @@ class BaseRenderer(ABC):
     ) -> list[EmbedsPrompt]:
         """Load and validate base64-encoded embeddings into prompt objects."""
         if not self.model_config.enable_prompt_embeds:
-            raise ValueError(
-                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
+            raise VLLMValidationError(
+                "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
+                parameter="prompt_embeds",
             )
 
         def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
@@ -396,10 +398,12 @@ class CompletionRenderer(BaseRenderer):
     ) -> TokensPrompt:
         """Create validated TokensPrompt."""
         if max_length is not None and len(token_ids) > max_length:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"This model's maximum context length is {max_length} tokens. "
                 f"However, your request has {len(token_ids)} input tokens. "
-                "Please reduce the length of the input messages."
+                "Please reduce the length of the input messages.",
+                parameter="input_tokens",
+                value=len(token_ids),
             )
 
         tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
 
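Taken end to end, a client can now read which field was rejected straight from the error body. A hedged sketch (assumes a locally running vLLM OpenAI-compatible server on the default port; the model name and field values are placeholders, and the printed expectations follow from the validators above rather than a captured run):

import requests

# Hypothetical request that trips the stream_options check added above:
# stream_options is set while stream is left unset.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "hi"}],
        "stream_options": {"include_usage": True},
    },
)
print(resp.status_code)                   # expected: 400
print(resp.json()["error"].get("param"))  # expected: "stream_options"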