[CI/Build] Simplify OpenAI server setup in tests (#5100)

Cyrus Leung 2024-06-14 02:21:53 +08:00 committed by GitHub
parent 03dccc886e
commit 39873476f8
6 changed files with 284 additions and 237 deletions
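Summary: each test module now starts its OpenAI-compatible server through the new RemoteOpenAIServer helper in the shared test utils module, instead of driving a ServerRunner Ray actor by hand and pointing a hard-coded http://localhost:8000 client at it. The helper picks a free port, waits until /health returns 200, and hands back preconfigured OpenAI clients. A minimal sketch of the resulting per-module fixture pattern, paraphrased from the diffs below (the final test body is illustrative only, not part of this commit):

import openai
import pytest
import ray

from ..utils import VLLM_PATH, RemoteOpenAIServer

MODEL_NAME = "facebook/opt-125m"  # any model with a chat template works


@pytest.fixture(scope="module")
def ray_ctx():
    # make the repository checkout importable inside Ray workers
    ray.init(runtime_env={"working_dir": VLLM_PATH})
    yield
    ray.shutdown()


@pytest.fixture(scope="module")
def server(ray_ctx):
    # RemoteOpenAIServer appends a free --port and blocks until /health is up
    return RemoteOpenAIServer(["--model", MODEL_NAME])


@pytest.fixture(scope="module")
def client(server):
    # AsyncOpenAI client already pointed at the server's /v1 endpoint
    return server.get_async_client()


@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
    models = await client.models.list()
    assert models.data[0].id == MODEL_NAME

Because the server owns its port and base URL, test functions no longer need a server argument just to keep the fixture alive, and modules no longer compete for port 8000.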

View File

@@ -4,16 +4,22 @@ import pytest
 # and debugging.
 import ray
-from ..utils import ServerRunner
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
 @pytest.fixture(scope="module")
-def server():
-    ray.init()
-    server_runner = ServerRunner.remote([
+def ray_ctx():
+    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    yield
+    ray.shutdown()
+@pytest.fixture(scope="module")
+def server(ray_ctx):
+    return RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
@@ -24,22 +30,15 @@ def server():
         "--enforce-eager",
         "--engine-use-ray"
     ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
 @pytest.fixture(scope="module")
-def client():
-    client = openai.AsyncOpenAI(
-        base_url="http://localhost:8000/v1",
-        api_key="token-abc123",
-    )
-    yield client
+def client(server):
+    return server.get_async_client()
 @pytest.mark.asyncio
-async def test_check_models(server, client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
@@ -48,7 +47,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_single_completion(server, client: openai.AsyncOpenAI):
+async def test_single_completion(client: openai.AsyncOpenAI):
     completion = await client.completions.create(model=MODEL_NAME,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
@@ -72,7 +71,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+async def test_single_chat_session(client: openai.AsyncOpenAI):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"

View File

@@ -0,0 +1,113 @@
import openai
import pytest
import ray

from ..utils import VLLM_PATH, RemoteOpenAIServer

EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"

pytestmark = pytest.mark.openai


@pytest.fixture(scope="module")
def ray_ctx():
    ray.init(runtime_env={"working_dir": VLLM_PATH})
    yield
    ray.shutdown()


@pytest.fixture(scope="module")
def embedding_server(ray_ctx):
    return RemoteOpenAIServer([
        "--model",
        EMBEDDING_MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        "8192",
    ])


@pytest.fixture(scope="module")
def embedding_client(embedding_server):
    return embedding_server.get_async_client()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
                                model_name: str):
    input_texts = [
        "The chef prepared a delicious meal.",
    ]

    # test single embedding
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_texts,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 9
    assert embeddings.usage.total_tokens == 9

    # test using token IDs
    input_tokens = [1, 1, 1, 1, 1]
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_tokens,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 5
    assert embeddings.usage.total_tokens == 5


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
                               model_name: str):
    # test List[str]
    input_texts = [
        "The cat sat on the mat.", "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky."
    ]
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_texts,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert len(embeddings.data) == 3
    assert len(embeddings.data[0].embedding) == 4096

    # test List[List[int]]
    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                    [25, 32, 64, 77]]
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_tokens,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert len(embeddings.data) == 4
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 17
    assert embeddings.usage.total_tokens == 17

View File

@@ -15,11 +15,10 @@ from openai import BadRequestError
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from ..utils import ServerRunner
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 # technically this needs Mistral-7B-v0.1 as base, but we're not testing
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
@@ -80,9 +79,15 @@ def zephyr_lora_files():
 @pytest.fixture(scope="module")
-def server(zephyr_lora_files):
-    ray.init()
-    server_runner = ServerRunner.remote([
+def ray_ctx():
+    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    yield
+    ray.shutdown()
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, ray_ctx):
+    return RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
@@ -91,8 +96,6 @@ def server(zephyr_lora_files):
         "--max-model-len",
         "8192",
         "--enforce-eager",
-        "--gpu-memory-utilization",
-        "0.75",
         # lora config below
         "--enable-lora",
         "--lora-modules",
@@ -105,43 +108,14 @@ def server(zephyr_lora_files):
         "--max-num-seqs",
         "128",
     ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-@pytest.fixture(scope="module")
-def embedding_server(zephyr_lora_files):
-    ray.shutdown()
-    ray.init()
-    server_runner = ServerRunner.remote([
-        "--model",
-        EMBEDDING_MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--enforce-eager",
-        "--gpu-memory-utilization",
-        "0.75",
-        "--max-model-len",
-        "8192",
-    ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
 @pytest.fixture(scope="module")
-def client():
-    client = openai.AsyncOpenAI(
-        base_url="http://localhost:8000/v1",
-        api_key="token-abc123",
-    )
-    yield client
+def client(server):
+    return server.get_async_client()
 @pytest.mark.asyncio
-async def test_check_models(server, client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
@@ -158,8 +132,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
-async def test_single_completion(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
     completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
@@ -190,8 +163,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
-async def test_no_logprobs(server, client: openai.AsyncOpenAI,
-                           model_name: str):
+async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -210,8 +182,7 @@ async def test_no_logprobs(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
-                             model_name: str):
+async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -232,8 +203,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_some_logprobs(server, client: openai.AsyncOpenAI,
-                             model_name: str):
+async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -254,7 +224,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
+async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                             model_name: str):
     with pytest.raises(
@@ -300,8 +270,7 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
-async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
-                                model_name: str):
+async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -326,8 +295,7 @@ async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
-                                  model_name: str):
+async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -354,8 +322,7 @@ async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
-                                  model_name: str):
+async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -382,7 +349,7 @@ async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
                                       model_name: str):
     messages = [{
         "role": "system",
@@ -425,7 +392,7 @@ async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+async def test_single_chat_session(client: openai.AsyncOpenAI,
                                    model_name: str):
     messages = [{
         "role": "system",
@@ -470,7 +437,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_completion_streaming(server, client: openai.AsyncOpenAI,
+async def test_completion_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
     prompt = "What is an LLM?"
@@ -505,8 +472,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_chat_streaming(server, client: openai.AsyncOpenAI,
-                              model_name: str):
+async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -555,8 +521,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
     "model_name",
     ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
 )
-async def test_chat_completion_stream_options(server,
-                                              client: openai.AsyncOpenAI,
+async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
                                               model_name: str):
     messages = [{
         "role": "system",
@@ -626,7 +591,7 @@ async def test_chat_completion_stream_options(server,
     "model_name",
     ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
 )
-async def test_completion_stream_options(server, client: openai.AsyncOpenAI,
+async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                          model_name: str):
     prompt = "What is the capital of France?"
@@ -688,8 +653,7 @@ async def test_completion_stream_options(server, client: openai.AsyncOpenAI,
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
-async def test_batch_completions(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     # test simple list
     batch = await client.completions.create(
         model=model_name,
@@ -737,7 +701,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
-async def test_logits_bias(server, client: openai.AsyncOpenAI):
+async def test_logits_bias(client: openai.AsyncOpenAI):
     prompt = "Hello, my name is"
     max_tokens = 5
     tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
@@ -786,7 +750,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
+async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -808,7 +772,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
+async def test_guided_json_chat(client: openai.AsyncOpenAI,
                                 guided_decoding_backend: str):
     messages = [{
         "role": "system",
@@ -855,7 +819,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
+async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -875,7 +839,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
+async def test_guided_regex_chat(client: openai.AsyncOpenAI,
                                  guided_decoding_backend: str):
     messages = [{
         "role": "system",
@@ -913,7 +877,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
+async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                         guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -933,7 +897,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
+async def test_guided_choice_chat(client: openai.AsyncOpenAI,
                                   guided_decoding_backend: str):
     messages = [{
         "role": "system",
@@ -972,7 +936,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
+async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str):
     with pytest.raises(openai.BadRequestError):
         _ = await client.completions.create(
@@ -1008,7 +972,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
+async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
                                            guided_decoding_backend: str):
     messages = [{
         "role": "system",
@@ -1040,7 +1004,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
-async def test_named_tool_use(server, client: openai.AsyncOpenAI,
+async def test_named_tool_use(client: openai.AsyncOpenAI,
                               guided_decoding_backend: str):
     messages = [{
         "role": "system",
@@ -1131,7 +1095,7 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
 async def test_required_tool_use_not_yet_supported(
-        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
+        client: openai.AsyncOpenAI, guided_decoding_backend: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -1177,7 +1141,7 @@ async def test_required_tool_use_not_yet_supported(
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
 async def test_inconsistent_tool_choice_and_tools(
-        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
+        client: openai.AsyncOpenAI, guided_decoding_backend: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -1223,7 +1187,7 @@ async def test_inconsistent_tool_choice_and_tools(
 @pytest.mark.asyncio
-async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
+async def test_response_format_json_object(client: openai.AsyncOpenAI):
     for _ in range(2):
         resp = await client.chat.completions.create(
             model=MODEL_NAME,
@@ -1243,7 +1207,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_extra_fields(server, client: openai.AsyncOpenAI):
+async def test_extra_fields(client: openai.AsyncOpenAI):
     with pytest.raises(BadRequestError) as exc_info:
         await client.chat.completions.create(
             model=MODEL_NAME,
@@ -1259,7 +1223,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_complex_message_content(server, client: openai.AsyncOpenAI):
+async def test_complex_message_content(client: openai.AsyncOpenAI):
     resp = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=[{
@@ -1279,7 +1243,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_custom_role(server, client: openai.AsyncOpenAI):
+async def test_custom_role(client: openai.AsyncOpenAI):
     # Not sure how the model handles custom roles so we just check that
     # both string and complex message content are handled in the same way
@@ -1310,7 +1274,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
-async def test_guided_grammar(server, client: openai.AsyncOpenAI):
+async def test_guided_grammar(client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement
@@ -1351,7 +1315,7 @@ number: "1" | "2"
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
 @pytest.mark.parametrize("logprobs_arg", [1, 0])
-async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                        model_name: str, logprobs_arg: int):
     tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
     # test using text and token IDs
@@ -1380,7 +1344,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
-async def test_long_seed(server, client: openai.AsyncOpenAI):
+async def test_long_seed(client: openai.AsyncOpenAI):
     for seed in [
             torch.iinfo(torch.long).min - 1,
             torch.iinfo(torch.long).max + 1
@@ -1399,81 +1363,5 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
-                                model_name: str):
-    input_texts = [
-        "The chef prepared a delicious meal.",
-    ]
-    # test single embedding
-    embeddings = await client.embeddings.create(
-        model=model_name,
-        input=input_texts,
-        encoding_format="float",
-    )
-    assert embeddings.id is not None
-    assert len(embeddings.data) == 1
-    assert len(embeddings.data[0].embedding) == 4096
-    assert embeddings.usage.completion_tokens == 0
-    assert embeddings.usage.prompt_tokens == 9
-    assert embeddings.usage.total_tokens == 9
-    # test using token IDs
-    input_tokens = [1, 1, 1, 1, 1]
-    embeddings = await client.embeddings.create(
-        model=model_name,
-        input=input_tokens,
-        encoding_format="float",
-    )
-    assert embeddings.id is not None
-    assert len(embeddings.data) == 1
-    assert len(embeddings.data[0].embedding) == 4096
-    assert embeddings.usage.completion_tokens == 0
-    assert embeddings.usage.prompt_tokens == 5
-    assert embeddings.usage.total_tokens == 5
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
-                               model_name: str):
-    # test List[str]
-    input_texts = [
-        "The cat sat on the mat.", "A feline was resting on a rug.",
-        "Stars twinkle brightly in the night sky."
-    ]
-    embeddings = await client.embeddings.create(
-        model=model_name,
-        input=input_texts,
-        encoding_format="float",
-    )
-    assert embeddings.id is not None
-    assert len(embeddings.data) == 3
-    assert len(embeddings.data[0].embedding) == 4096
-    # test List[List[int]]
-    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
-                    [25, 32, 64, 77]]
-    embeddings = await client.embeddings.create(
-        model=model_name,
-        input=input_tokens,
-        encoding_format="float",
-    )
-    assert embeddings.id is not None
-    assert len(embeddings.data) == 4
-    assert len(embeddings.data[0].embedding) == 4096
-    assert embeddings.usage.completion_tokens == 0
-    assert embeddings.usage.prompt_tokens == 17
-    assert embeddings.usage.total_tokens == 17
 if __name__ == "__main__":
     pytest.main([__file__])

View File

@@ -8,7 +8,7 @@ import ray
 from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
-from ..utils import ServerRunner
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
 LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
@@ -25,10 +25,16 @@ TEST_IMAGE_URLS = [
 pytestmark = pytest.mark.openai
+@pytest.fixture(scope="module")
+def ray_ctx():
+    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    yield
+    ray.shutdown()
 @pytest.fixture(scope="module")
 def server():
-    ray.init()
-    server_runner = ServerRunner.remote([
+    return RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         "--dtype",
@@ -47,18 +53,11 @@ def server():
         "--chat-template",
         str(LLAVA_CHAT_TEMPLATE),
     ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-@pytest.fixture(scope="session")
-def client():
-    client = openai.AsyncOpenAI(
-        base_url="http://localhost:8000/v1",
-        api_key="token-abc123",
-    )
-    yield client
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
 @pytest_asyncio.fixture(scope="session")
@@ -73,7 +72,7 @@ async def base64_encoded_image() -> Dict[str, str]:
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
-async def test_single_chat_session_image(server, client: openai.AsyncOpenAI,
+async def test_single_chat_session_image(client: openai.AsyncOpenAI,
                                          model_name: str, image_url: str):
     messages = [{
         "role":
@@ -126,7 +125,7 @@ async def test_single_chat_session_image(server, client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_single_chat_session_image_base64encoded(
-        server, client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        client: openai.AsyncOpenAI, model_name: str, image_url: str,
         base64_encoded_image: Dict[str, str]):
     messages = [{
@@ -180,7 +179,7 @@ async def test_single_chat_session_image_base64encoded(
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
-async def test_chat_streaming_image(server, client: openai.AsyncOpenAI,
+async def test_chat_streaming_image(client: openai.AsyncOpenAI,
                                     model_name: str, image_url: str):
     messages = [{
         "role":
@@ -237,8 +236,8 @@ async def test_chat_streaming_image(server, client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
-async def test_multi_image_input(server, client: openai.AsyncOpenAI,
-                                 model_name: str, image_url: str):
+async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
+                                 image_url: str):
     messages = [{
         "role":

View File

@@ -22,11 +22,12 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          tensorize_vllm_model)
 from ..conftest import VllmRunner, cleanup
-from ..utils import ServerRunner
+from ..utils import RemoteOpenAIServer
 # yapf conflicts with isort for this docstring
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -216,18 +217,13 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
     openai_args = [
         "--model", model_ref, "--dtype", "float16", "--load-format",
         "tensorizer", "--model-loader-extra-config",
-        json.dumps(model_loader_extra_config), "--port", "8000"
+        json.dumps(model_loader_extra_config),
     ]
-    server = ServerRunner.remote(openai_args)
-    assert ray.get(server.ready.remote())
+    server = RemoteOpenAIServer(openai_args)
     print("Server ready.")
-    client = openai.OpenAI(
-        base_url="http://localhost:8000/v1",
-        api_key="token-abc123",
-    )
+    client = server.get_client()
     completion = client.completions.create(model=model_ref,
                                            prompt="Hello, my name is",
                                            max_tokens=5,

View File
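The file below adds the RemoteOpenAIServer helper itself, replacing the ServerRunner actor. For reference, a small usage sketch of its two client accessors; the model name is a placeholder, and Ray is assumed to be initialized beforehand (or left to initialize implicitly when the inner runner actor is created), as in the tensorizer test above:

from ..utils import RemoteOpenAIServer

server = RemoteOpenAIServer(["--model", "facebook/opt-125m"])

# blocking client, e.g. for plain (non-asyncio) tests
client = server.get_client()
completion = client.completions.create(model="facebook/opt-125m",
                                       prompt="Hello, my name is",
                                       max_tokens=5)

# awaitable client for pytest-asyncio tests
async_client = server.get_async_client()

# both clients point at the automatically chosen port
print(server.url_root)       # e.g. http://localhost:51234
print(server.url_for("v1"))  # base_url used by get_client()/get_async_client()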

@@ -4,57 +4,109 @@ import sys
 import time
 import warnings
 from contextlib import contextmanager
+from typing import List
+import openai
 import ray
 import requests
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.utils import get_open_port
 # Path to root of repository so that utilities can be imported by ray workers
 VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
-@ray.remote(num_gpus=1)
-class ServerRunner:
-    MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
-
-    def __init__(self, args):
-        env = os.environ.copy()
-        env["PYTHONUNBUFFERED"] = "1"
-        self.proc = subprocess.Popen(
-            [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
-            args,
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_server()
-
-    def ready(self):
-        return True
-
-    def _wait_for_server(self):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get(
-                        "http://localhost:8000/health").status_code == 200:
-                    break
-            except Exception as err:
-                if self.proc.poll() is not None:
-                    raise RuntimeError("Server exited unexpectedly.") from err
-                time.sleep(0.5)
-                if time.time() - start > self.MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError(
-                        "Server failed to start in time.") from err
-
-    def __del__(self):
-        if hasattr(self, "proc"):
-            self.proc.terminate()
+class RemoteOpenAIServer:
+    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
+    MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
+
+    @ray.remote(num_gpus=1)
+    class _RemoteRunner:
+
+        def __init__(self, cli_args: List[str], *, wait_url: str,
+                     wait_timeout: float) -> None:
+            env = os.environ.copy()
+            env["PYTHONUNBUFFERED"] = "1"
+            self.proc = subprocess.Popen(
+                [
+                    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
+                    *cli_args
+                ],
+                env=env,
+                stdout=sys.stdout,
+                stderr=sys.stderr,
+            )
+            self._wait_for_server(url=wait_url, timeout=wait_timeout)
+
+        def ready(self):
+            return True
+
+        def _wait_for_server(self, *, url: str, timeout: float):
+            # run health check
+            start = time.time()
+            while True:
+                try:
+                    if requests.get(url).status_code == 200:
+                        break
+                except Exception as err:
+                    if self.proc.poll() is not None:
+                        raise RuntimeError(
+                            "Server exited unexpectedly.") from err
+                    time.sleep(0.5)
+                    if time.time() - start > timeout:
+                        raise RuntimeError(
+                            "Server failed to start in time.") from err
+
+        def __del__(self):
+            if hasattr(self, "proc"):
+                self.proc.terminate()
+
+    def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
+        if auto_port:
+            if "-p" in cli_args or "--port" in cli_args:
+                raise ValueError("You have manually specified the port "
+                                 "when `auto_port=True`.")
+            cli_args = cli_args + ["--port", str(get_open_port())]
+
+        parser = make_arg_parser()
+        args = parser.parse_args(cli_args)
+        self.host = str(args.host or 'localhost')
+        self.port = int(args.port)
+
+        self._runner = self._RemoteRunner.remote(
+            cli_args,
+            wait_url=self.url_for("health"),
+            wait_timeout=self.MAX_SERVER_START_WAIT_S)
+
+        self._wait_until_ready()
+
+    @property
+    def url_root(self) -> str:
+        return f"http://{self.host}:{self.port}"
+
+    def url_for(self, *parts: str) -> str:
+        return self.url_root + "/" + "/".join(parts)
+
+    def _wait_until_ready(self) -> None:
+        ray.get(self._runner.ready.remote())
+
+    def get_client(self):
+        return openai.OpenAI(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+        )
+
+    def get_async_client(self):
+        return openai.AsyncOpenAI(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+        )
 def init_test_distributed_environment(