# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"

BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_completion = OpenAIServingCompletion(
        engine,
        models,
        request_logger=None,
    )

    # Bypass real input processing: pass the engine prompt through unchanged.
    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_completion


@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """Test that finish_reason='error' returns a 500 InternalServerError (non-streaming)."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    # finish_reason="error" simulates an engine-side failure for this request.
    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )
    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=False,
    )

    response = await serving_completion.create_completion(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InternalServerError"
    assert response.error.message == "Internal server error"
    assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_completion_error_stream():
    """Test that finish_reason='error' returns a 500 InternalServerError (streaming)."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    # First chunk: a normal, unfinished generation.
    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )
    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )
    # Final chunk: finish_reason="error" simulates an engine-side failure mid-stream.
    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )
    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=True,
    )

    response = await serving_completion.create_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    # The stream should surface the error message and still terminate with [DONE].
    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"