# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
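

# MockHFConfig and MockModelConfig are lightweight stand-ins for vLLM's HF and
# model configs: they expose only the attributes the OpenAI serving layer reads,
# so OpenAIServingCompletion can be constructed without loading a real model.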
@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}
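

# Helper that wires a mocked AsyncLLM into the OpenAI completion serving layer.
# _process_inputs is stubbed with an AsyncMock so requests bypass real input
# processing and flow straight to the mocked generate() call in each test.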
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_completion = OpenAIServingCompletion(
        engine,
        models,
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_completion
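

# Non-streaming case: generate() yields a single finished RequestOutput whose
# only CompletionOutput carries finish_reason="error"; create_completion()
# should return an ErrorResponse rather than a normal completion body.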
@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """finish_reason='error' surfaces a 500 InternalServerError (non-streaming)."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=False,
    )

    response = await serving_completion.create_completion(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InternalServerError"
    assert response.error.message == "Internal server error"
    assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
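

# Streaming case: generate() first yields a normal chunk and then a finished
# chunk with finish_reason="error"; the SSE stream should carry the error
# payload and still end with the usual "data: [DONE]" sentinel.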
@pytest.mark.asyncio
async def test_completion_error_stream():
    """finish_reason='error' surfaces an internal server error chunk (streaming)."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    # First chunk: normal, unfinished output.
    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    # Second chunk: the request finishes with finish_reason="error".
    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=True,
    )

    response = await serving_completion.create_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"