Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[responsesAPI] add better error messaging for long prompts (#25724)
Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
parent c1ffcb55da
commit 831b124151
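The diff below adds a pre-generation length check to the Responses API path: if the tokenized prompt does not fit within max_model_len, the server now returns a structured invalid_request_error (HTTP 400) instead of failing later inside the engine. As a minimal standalone sketch of that idea, the helper below is hypothetical (a plain dict stands in for vLLM's ErrorResponse, and validate_prompt_length is not a real vLLM helper); only the threshold comparison and the message wording mirror the diff.

# Hypothetical standalone sketch of the new length check; a plain dict stands
# in for vLLM's ErrorResponse, and validate_prompt_length is illustrative only.
from http import HTTPStatus
from typing import Optional


def validate_prompt_length(prompt_token_ids: list[int],
                           max_model_len: int) -> Optional[dict]:
    """Return an error payload when the prompt cannot fit the context window."""
    if max_model_len <= len(prompt_token_ids):
        return {
            "type": "invalid_request_error",
            "message": (f"The engine prompt length {len(prompt_token_ids)} "
                        f"exceeds the max_model_len {max_model_len}. "
                        "Please reduce prompt."),
            "status_code": HTTPStatus.BAD_REQUEST,
        }
    return None


# A 200-token prompt against a 100-token window is rejected; a short one passes.
assert validate_prompt_length(list(range(200)), 100) is not None
assert validate_prompt_length(list(range(5)), 100) is None
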
@@ -8,9 +8,10 @@ import pytest
 import pytest_asyncio

 from vllm.entrypoints.context import ConversationContext
-from vllm.entrypoints.openai.protocol import ResponsesRequest
+from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
 from vllm.entrypoints.tool_server import ToolServer
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt


 class MockConversationContext(ConversationContext):
@@ -127,3 +128,63 @@ class TestInitializeToolSessions:

         # Verify that init_tool_sessions was called
         assert mock_context.init_tool_sessions_called
+
+
+class TestValidateGeneratorInput:
+    """Test class for _validate_generator_input method"""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing"""
+        # Create minimal mocks for required dependencies
+        engine_client = MagicMock()
+        engine_client.get_model_config = AsyncMock()
+
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+
+        models = MagicMock()
+
+        # Create the actual instance
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+        )
+
+        # Set max_model_len for testing
+        instance.max_model_len = 100
+
+        return instance
+
+    def test_validate_generator_input(self, serving_responses_instance):
+        """Test _validate_generator_input with valid prompt length"""
+        # Create an engine prompt with valid length (less than max_model_len)
+        valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=valid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return None for valid input
+        assert result is None
+
+        # Create an invalid engine prompt
+        invalid_prompt_token_ids = list(
+            range(200))  # 200 tokens >= 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=invalid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return an ErrorResponse
+        assert result is not None
+        assert isinstance(result, ErrorResponse)

@@ -192,6 +192,23 @@ class OpenAIServingResponses(OpenAIServing):

         self.tool_server = tool_server

+    def _validate_generator_input(
+            self,
+            engine_prompt: EngineTokensPrompt) -> Optional[ErrorResponse]:
+        """Add validations to the input to the generator here."""
+        if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
+            error_message = (
+                "The engine prompt length"
+                f" {len(engine_prompt['prompt_token_ids'])} "
+                f"exceeds the max_model_len {self.max_model_len}. "
+                "Please reduce prompt.")
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message=error_message,
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        return None
+
     async def create_responses(
         self,
         request: ResponsesRequest,
@@ -287,8 +304,13 @@ class OpenAIServingResponses(OpenAIServing):
         available_tools = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
+                maybe_error = self._validate_generator_input(engine_prompt)
+                if maybe_error is not None:
+                    return maybe_error
+
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
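
With the check in place, an over-long prompt to /v1/responses should surface to the client as an HTTP 400 invalid_request_error carrying the message above, rather than an opaque failure during generation. A hypothetical client-side illustration, assuming a vLLM OpenAI-compatible server on localhost:8000, a placeholder model name, and an openai SDK version that exposes client.responses.create:

# Hypothetical usage sketch: base_url, api_key, and model are placeholders.
from openai import BadRequestError, OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    client.responses.create(
        model="my-model",           # placeholder model name served by vLLM
        input="word " * 1_000_000,  # deliberately longer than max_model_len
    )
except BadRequestError as exc:
    # Expected: a 400 invalid_request_error explaining that the prompt exceeds
    # max_model_len, as produced by _validate_generator_input above.
    print(exc)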