[Tests] Disable retries and use context manager for openai client (#7565)
parent 2eedede875
commit 39178c7fbc

@@ -1,5 +1,6 @@
 import openai  # use the official client for correctness check
 import pytest
+import pytest_asyncio

 from ..utils import VLLM_PATH, RemoteOpenAIServer

@@ -31,9 +32,10 @@ def server():
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

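The fixture change above is the pattern this commit applies to most of the test files: the old module-scoped fixture returned an openai.AsyncOpenAI client and never closed it, while the new pytest_asyncio fixture yields the client from an async with block so its HTTP connections are released at teardown. A minimal sketch of that pattern in isolation, assuming a server is already listening at a placeholder URL (the URL, API key, and test below are illustrative, not taken from this diff):

import openai
import pytest
import pytest_asyncio


@pytest_asyncio.fixture
async def client():
    # Entering the client as an async context manager closes the underlying
    # httpx connections when the fixture is torn down.
    async with openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                  api_key="sk-placeholder",
                                  max_retries=0) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_models_endpoint(client: openai.AsyncOpenAI):
    # Assumes a placeholder server is running and serves at least one model.
    models = await client.models.list()
    assert len(models.data) > 0
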
@@ -2,6 +2,7 @@ from typing import Dict, List

 import openai
 import pytest
+import pytest_asyncio

 from vllm.assets.audio import AudioAsset
 from vllm.multimodal.utils import encode_audio_base64, fetch_audio
@@ -28,9 +29,10 @@ def server():
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.fixture(scope="session")

@@ -2,6 +2,7 @@ from http import HTTPStatus

 import openai
 import pytest
+import pytest_asyncio
 import requests

 from vllm.version import __version__ as VLLM_VERSION
@@ -28,9 +29,10 @@ def server():
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
+import pytest_asyncio
 import torch
 from openai import BadRequestError

@@ -46,9 +47,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -8,6 +8,7 @@ from typing import Dict, List, Optional
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
+import pytest_asyncio
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -89,11 +90,17 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,

 @pytest.fixture(scope="module",
                 params=["", "--disable-frontend-multiprocessing"])
-def client(default_server_args, request):
+def server(default_server_args, request):
     if request.param:
         default_server_args.append(request.param)
     with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
-        yield remote_server.get_async_client()
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

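In this parametrized module (and the metrics one further down), the old client fixture built the server and yielded remote_server.get_async_client() in one step; the commit splits it into a server fixture that yields the running RemoteOpenAIServer and a separate pytest_asyncio client fixture layered on top. A rough sketch of that shape, using a stand-in class because RemoteOpenAIServer lives in vLLM's test utilities and actually launches a server subprocess (URL and key below are placeholders):

import contextlib

import openai
import pytest
import pytest_asyncio


class FakeRemoteServer:
    """Stand-in with the same get_async_client() interface; starts nothing."""

    def __init__(self, extra_args):
        self.extra_args = extra_args

    def get_async_client(self) -> openai.AsyncOpenAI:
        return openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                  api_key="sk-placeholder",
                                  max_retries=0)


@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server(request):
    # The real fixture appends request.param to the serve CLI args before
    # starting the server; the stand-in only records them.
    extra_args = [request.param] if request.param else []
    with contextlib.nullcontext(FakeRemoteServer(extra_args)) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client

Splitting the fixtures keeps the expensive server at module scope while giving every test a freshly opened and properly closed client.
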
@@ -3,6 +3,7 @@ import base64

 import numpy as np
 import openai
 import pytest
+import pytest_asyncio

 from ...utils import RemoteOpenAIServer
@@ -24,10 +25,10 @@ def embedding_server():
         yield remote_server


-@pytest.mark.asyncio
-@pytest.fixture(scope="module")
-def embedding_client(embedding_server):
-    return embedding_server.get_async_client()
+@pytest_asyncio.fixture
+async def embedding_client(embedding_server):
+    async with embedding_server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -1,5 +1,6 @@
 import openai
 import pytest
+import pytest_asyncio

 from ...utils import RemoteOpenAIServer

@@ -18,9 +19,10 @@ def server():
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -6,6 +6,7 @@ from http import HTTPStatus

 import openai
 import pytest
+import pytest_asyncio
 import requests
 from prometheus_client.parser import text_string_to_metric_families
 from transformers import AutoTokenizer
@@ -35,11 +36,17 @@ def default_server_args():
                     "--enable-chunked-prefill",
                     "--disable-frontend-multiprocessing",
                 ])
-def client(default_server_args, request):
+def server(default_server_args, request):
     if request.param:
         default_server_args.append(request.param)
     with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
-        yield remote_server.get_async_client()
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as cl:
+        yield cl


 _PROMPT = "Hello my name is Robert and I love magic"

@@ -1,5 +1,6 @@
 import openai  # use the official client for correctness check
 import pytest
+import pytest_asyncio
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download

@@ -43,9 +44,10 @@ def server(zephyr_lora_files):
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -25,59 +25,63 @@ def server_with_return_tokens_as_token_ids_flag(
 @pytest.mark.asyncio
 async def test_completion_return_tokens_as_token_ids_completion(
         server_with_return_tokens_as_token_ids_flag):
-    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
+    ) as client:

-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        # Include Unicode characters to test for dividing a single
-        # character across multiple tokens: 🎉 is [28705, 31862] for the
-        # Zephyr tokenizer
-        prompt="Say 'Hello, world! 🎉'",
-        echo=True,
-        temperature=0,
-        max_tokens=10,
-        logprobs=1)
+        completion = await client.completions.create(
+            model=MODEL_NAME,
+            # Include Unicode characters to test for dividing a single
+            # character across multiple tokens: 🎉 is [28705, 31862] for the
+            # Zephyr tokenizer
+            prompt="Say 'Hello, world! 🎉'",
+            echo=True,
+            temperature=0,
+            max_tokens=10,
+            logprobs=1)

-    text = completion.choices[0].text
-    token_strs = completion.choices[0].logprobs.tokens
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
-    # Check that the token representations are consistent between raw tokens
-    # and top_logprobs
-    # Slice off the first one, because there's no scoring associated with BOS
-    top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
-    top_logprob_keys = [
-        next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
-    ]
-    assert token_strs[1:] == top_logprob_keys
+        text = completion.choices[0].text
+        token_strs = completion.choices[0].logprobs.tokens
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+        # Check that the token representations are consistent between raw
+        # tokens and top_logprobs
+        # Slice off the first one, because there's no scoring associated
+        # with BOS
+        top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
+        top_logprob_keys = [
+            next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
+        ]
+        assert token_strs[1:] == top_logprob_keys

-    # Check that decoding the tokens gives the expected text
-    tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
-    assert text == tokenizer.decode(tokens, skip_special_tokens=True)
+        # Check that decoding the tokens gives the expected text
+        tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
+        assert text == tokenizer.decode(tokens, skip_special_tokens=True)


 @pytest.mark.asyncio
 async def test_chat_return_tokens_as_token_ids_completion(
         server_with_return_tokens_as_token_ids_flag):
-    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
-    response = await client.chat.completions.create(
-        model=MODEL_NAME,
-        # Include Unicode characters to test for dividing a single
-        # character across multiple tokens: 🎉 is [28705, 31862] for the
-        # Zephyr tokenizer
-        messages=[{
-            "role": "system",
-            "content": "You like to respond in only emojis, like 🎉"
-        }, {
-            "role": "user",
-            "content": "Please write some emojis: 🐱🐶🎉"
-        }],
-        temperature=0,
-        max_tokens=8,
-        logprobs=True)
+    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
+    ) as client:
+        response = await client.chat.completions.create(
+            model=MODEL_NAME,
+            # Include Unicode characters to test for dividing a single
+            # character across multiple tokens: 🎉 is [28705, 31862] for the
+            # Zephyr tokenizer
+            messages=[{
+                "role": "system",
+                "content": "You like to respond in only emojis, like 🎉"
+            }, {
+                "role": "user",
+                "content": "Please write some emojis: 🐱🐶🎉"
+            }],
+            temperature=0,
+            max_tokens=8,
+            logprobs=True)

-    text = response.choices[0].message.content
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
-    token_ids = []
-    for logprob_content in response.choices[0].logprobs.content:
-        token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
-    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
+        text = response.choices[0].message.content
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+        token_ids = []
+        for logprob_content in response.choices[0].logprobs.content:
+            token_ids.append(
+                int(logprob_content.token.removeprefix("token_id:")))
+        assert tokenizer.decode(token_ids, skip_special_tokens=True) == text

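The same pattern is applied inline where a test or helper obtains the client itself rather than through a fixture: the get_async_client() call moves into an async with and the body is re-indented one level. A compressed before/after sketch with placeholder names (the real tests above use the server_with_return_tokens_as_token_ids_flag fixture and verify token-id round-trips):

# `server` stands in for any object exposing get_async_client();
# MODEL_NAME is a placeholder, not taken from this diff.
MODEL_NAME = "placeholder-model"


# Before: the client is created but never closed.
async def query_without_cleanup(server):
    client = server.get_async_client()
    return await client.completions.create(model=MODEL_NAME,
                                            prompt="Hello",
                                            max_tokens=5)


# After: identical call, but the client's connections are closed on exit.
async def query_with_cleanup(server):
    async with server.get_async_client() as client:
        return await client.completions.create(model=MODEL_NAME,
                                                prompt="Hello",
                                                max_tokens=5)
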
@@ -35,13 +35,14 @@ async def test_shutdown_on_engine_failure(tmp_path):
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        client = remote_server.get_async_client()
+        async with remote_server.get_async_client() as client:

-        with pytest.raises(openai.APIConnectionError):
-            # This crashes the engine
-            await client.completions.create(model="bad-adapter",
-                                            prompt="Hello, my name is")
+            with pytest.raises(
+                    (openai.APIConnectionError, openai.InternalServerError)):
+                # This crashes the engine
+                await client.completions.create(model="bad-adapter",
+                                                prompt="Hello, my name is")

-        # Now the server should shut down
-        return_code = remote_server.proc.wait(timeout=1)
-        assert return_code is not None
+            # Now the server should shut down
+            return_code = remote_server.proc.wait(timeout=3)
+            assert return_code is not None

@@ -1,5 +1,6 @@
 import openai  # use the official client for correctness check
 import pytest
+import pytest_asyncio
 import requests

 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -42,9 +43,10 @@ def tokenizer_name(model_name: str,
         model_name == "zephyr-lora2") else model_name


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.mark.asyncio

@@ -2,6 +2,7 @@ from typing import Dict, List

 import openai
 import pytest
+import pytest_asyncio

 from vllm.multimodal.utils import encode_image_base64, fetch_image

@@ -36,9 +37,10 @@ def server():
         yield remote_server


-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client


 @pytest.fixture(scope="session")

@@ -28,12 +28,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,

     outputs = None
     with RemoteOpenAIServer(model_name, server_cli_args) as server:
-        client = server.get_async_client()
-        outputs = await client.completions.create(model=model_name,
-                                                   prompt=prompts,
-                                                   temperature=0,
-                                                   stream=False,
-                                                   max_tokens=5)
+        async with server.get_async_client() as client:
+            outputs = await client.completions.create(model=model_name,
+                                                       prompt=prompts,
+                                                       temperature=0,
+                                                       stream=False,
+                                                       max_tokens=5)
     assert outputs is not None

     return outputs

@@ -154,6 +154,7 @@ class RemoteOpenAIServer:
         return openai.AsyncOpenAI(
             base_url=self.url_for("v1"),
             api_key=self.DUMMY_API_KEY,
+            max_retries=0,
         )

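The last hunk is where retries are actually disabled: RemoteOpenAIServer.get_async_client() now passes max_retries=0, so a request against a crashed or shutting-down server fails on the first attempt instead of being retried with backoff, which is presumably also why the shutdown test above now accepts openai.InternalServerError alongside APIConnectionError and waits up to 3 seconds for the process to exit. A hedged sketch of such a helper with a placeholder URL and key (the real method derives the URL from the running server's host and port):

import openai


def get_async_client() -> openai.AsyncOpenAI:
    return openai.AsyncOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="sk-placeholder",
        max_retries=0,  # fail fast; do not retry failed requests with backoff
    )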