# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import random
from typing import Callable

import openai
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
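

# Module-scoped server fixture: starts a single vLLM OpenAI-compatible server
# that is reused by every test in this file. The dummy load format skips
# loading real model weights, keeping the CI run fast and light on memory.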
@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        "--load-format",
        "dummy",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server
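

# Per-test async OpenAI client bound to the shared server fixture above.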
@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client
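

# Send a batch of concurrent requests, half with truncate_prompt_tokens set and
# half without, against both the completions and chat completions endpoints.
# Untruncated prompts exceed --max-model-len and may be rejected, but the test
# asserts that no request ever results in an HTTP 500 internal server error.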
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["completion", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
        (
            lambda x: x.chat.completions.create,
            {"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
        ),
    ],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    # Interleave truncated and untruncated requests in random order.
    num_requests = 10
    truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
        num_requests - num_requests // 2
    )
    random.shuffle(truncate_prompt_tokens)

    bodies = [
        {**body, "extra_body": {"truncate_prompt_tokens": t}}
        for t in truncate_prompt_tokens
    ]

    async def get_status_code(**kwargs):
        try:
            await create_func(**kwargs)
            return 200
        except openai.APIStatusError as e:
            return e.status_code

    responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
    # Requests may be rejected (e.g. prompt too long), but the server must
    # never answer with an internal error.
    assert 500 not in responses