diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py
index 0cce880160920..58d92f72dfae0 100644
--- a/tests/entrypoints/openai/test_serving_responses.py
+++ b/tests/entrypoints/openai/test_serving_responses.py
@@ -8,9 +8,10 @@
 import pytest
 import pytest_asyncio
 from vllm.entrypoints.context import ConversationContext
-from vllm.entrypoints.openai.protocol import ResponsesRequest
+from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
 from vllm.entrypoints.tool_server import ToolServer
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 
 
 class MockConversationContext(ConversationContext):
@@ -127,3 +128,63 @@ class TestInitializeToolSessions:
 
         # Verify that init_tool_sessions was called
         assert mock_context.init_tool_sessions_called
+
+
+class TestValidateGeneratorInput:
+    """Test class for _validate_generator_input method"""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing"""
+        # Create minimal mocks for required dependencies
+        engine_client = MagicMock()
+        engine_client.get_model_config = AsyncMock()
+
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+
+        models = MagicMock()
+
+        # Create the actual instance
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+        )
+
+        # Set max_model_len for testing
+        instance.max_model_len = 100
+
+        return instance
+
+    def test_validate_generator_input(self, serving_responses_instance):
+        """Test _validate_generator_input with valid and invalid prompts"""
+        # Create an engine prompt with valid length (less than max_model_len)
+        valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=valid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return None for valid input
+        assert result is None
+
+        # Create an invalid engine prompt
+        invalid_prompt_token_ids = list(
+            range(200))  # 200 tokens >= 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=invalid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return an ErrorResponse for a prompt that is too long
+        assert result is not None
+        assert isinstance(result, ErrorResponse)
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 4e7418920954a..faaed2fca3927 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -192,6 +192,23 @@ class OpenAIServingResponses(OpenAIServing):
 
         self.tool_server = tool_server
 
+    def _validate_generator_input(
+            self,
+            engine_prompt: EngineTokensPrompt) -> Optional[ErrorResponse]:
+        """Add validations to the input to the generator here."""
+        if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
+            error_message = (
+                "The engine prompt length"
+                f" {len(engine_prompt['prompt_token_ids'])} "
+                f"exceeds the max_model_len {self.max_model_len}. "
+                "Please reduce the prompt length.")
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message=error_message,
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        return None
+
     async def create_responses(
         self,
         request: ResponsesRequest,
@@ -287,8 +304,13 @@ class OpenAIServingResponses(OpenAIServing):
             available_tools = []
 
         try:
             for i, engine_prompt in enumerate(engine_prompts):
+                maybe_error = self._validate_generator_input(engine_prompt)
+                if maybe_error is not None:
+                    return maybe_error
+
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)