From 7134303cbbb7c82cdfcb0c87d59bb48fe6ad642b Mon Sep 17 00:00:00 2001
From: Roy
Date: Sat, 27 Apr 2024 19:30:08 +0800
Subject: [PATCH] [Bugfix][Core] Fix get decoding config from ray (#4335)

---
 tests/async_engine/test_async_llm_engine.py   |   2 +
 tests/async_engine/test_openapi_server_ray.py | 157 ++++++++++++++++++
 vllm/engine/async_llm_engine.py               |  10 +-
 vllm/engine/llm_engine.py                     |   4 +
 vllm/entrypoints/openai/serving_chat.py       |   2 +-
 vllm/entrypoints/openai/serving_completion.py |   2 +-
 6 files changed, 174 insertions(+), 3 deletions(-)
 create mode 100644 tests/async_engine/test_openapi_server_ray.py

diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index cb125a7bfec3..b69cdc0a2140 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -91,4 +91,6 @@ async def test_new_requests_event():
     assert engine.engine.step_calls == old_step_calls + 1
 
     engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
+    assert engine.get_decoding_config() is not None
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
new file mode 100644
index 000000000000..4b97af88012b
--- /dev/null
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -0,0 +1,157 @@
+# imports for the OpenAI API server tests
+import os
+import subprocess
+import sys
+import time
+
+import openai  # use the official client for correctness check
+import pytest
+# using Ray for overall ease of process management, parallel requests,
+# and debugging.
+import ray
+import requests
+
+MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
+# any model with a chat template should work here
+MODEL_NAME = "facebook/opt-125m"
+
+
+@ray.remote(num_gpus=1)
+class ServerRunner:
+
+    def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
+        self.proc = subprocess.Popen(
+            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+        self._wait_for_server()
+
+    def ready(self):
+        return True
+
+    def _wait_for_server(self):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(
+                        "http://localhost:8000/health").status_code == 200:
+                    break
+            except Exception as err:
+                if self.proc.poll() is not None:
+                    raise RuntimeError("Server exited unexpectedly.") from err
+
+                time.sleep(0.5)
+                if time.time() - start > MAX_SERVER_START_WAIT_S:
+                    raise RuntimeError(
+                        "Server failed to start in time.") from err
+
+    def __del__(self):
+        if hasattr(self, "proc"):
+            self.proc.terminate()
+
+
+@pytest.fixture(scope="session")
+def server():
+    ray.init()
+    server_runner = ServerRunner.remote([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--max-model-len",
+        "2048",
+        "--enforce-eager",
+        "--engine-use-ray"
+    ])
+    ray.get(server_runner.ready.remote())
+    yield server_runner
+    ray.shutdown()
+
+
+@pytest.fixture(scope="session")
+def client():
+    client = openai.AsyncOpenAI(
+        base_url="http://localhost:8000/v1",
+        api_key="token-abc123",
+    )
+    yield client
+
+
+@pytest.mark.asyncio
+async def test_check_models(server, client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in 
models) + + +@pytest.mark.asyncio +async def test_single_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + + +@pytest.mark.asyncio +async def test_single_chat_session(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert chat_completion.id is not None + assert chat_completion.choices is not None and len( + chat_completion.choices) == 1 + assert chat_completion.choices[0].message is not None + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.top_logprobs is not None + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 89ee3f0db491..7c1eb2ecbe55 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,7 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, from transformers import PreTrainedTokenizer -from vllm.config import ModelConfig +from vllm.config import DecodingConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -697,6 +697,14 @@ class AsyncLLMEngine: else: return self.engine.get_model_config() + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_decoding_config.remote( # type: ignore + ) + else: + return self.engine.get_decoding_config() + async def do_log_stats(self) -> None: if self.engine_use_ray: await self.engine.do_log_stats.remote() # type: ignore diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 741d3bcd8089..292504631b06 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -467,6 +467,10 @@ class LLMEngine: """Gets the model configuration.""" return self.model_config + def 
get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return self.scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 629dd929dc1a..5ed042ef386e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -101,7 +101,7 @@ class OpenAIServingChat(OpenAIServing): request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7904bb698c45..6a7f29c4c96f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -89,7 +89,7 @@ class OpenAIServingCompletion(OpenAIServing): try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logit_processor = (
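
For context, a minimal sketch of how callers are expected to consume the new accessor after this change: the decoding config is now fetched through the async AsyncLLMEngine.get_decoding_config() instead of reaching into self.engine.engine.decoding_config, which is a Ray actor handle rather than an LLMEngine when --engine-use-ray is enabled. This is a sketch under stated assumptions, not part of the patch; the helper name resolve_guided_decoding_backend and the engine arguments are illustrative.

    # Minimal sketch, not part of the patch: resolve the guided decoding backend
    # the same way serving_chat.py / serving_completion.py do after this fix.
    import asyncio
    from typing import Optional

    from vllm.config import DecodingConfig
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    async def resolve_guided_decoding_backend(
            engine: AsyncLLMEngine,
            requested_backend: Optional[str] = None) -> str:
        # Per-request override first, otherwise the engine-wide default from the
        # decoding config returned by the new async accessor.
        decoding_config: DecodingConfig = await engine.get_decoding_config()
        return requested_backend or decoding_config.guided_decoding_backend


    async def main() -> None:
        # Illustrative engine arguments; any small model works. With
        # engine_use_ray=True the LLMEngine runs inside a separate Ray actor,
        # which is exactly the configuration the old attribute access broke on.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m",
                            engine_use_ray=True,
                            enforce_eager=True))
        print(await resolve_guided_decoding_backend(engine))


    if __name__ == "__main__":
        asyncio.run(main())

The accessor mirrors the existing get_model_config(): it awaits a .remote() call when engine_use_ray is set and falls through to the local LLMEngine otherwise, so both OpenAI serving entrypoints work unchanged in either mode.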