From 1696725879f25de03ca36cde764102bba60ff681 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 4 Sep 2023 17:41:22 -0700
Subject: [PATCH] Initialize AsyncLLMEngine bg loop correctly (#943)

---
 vllm/engine/async_llm_engine.py       | 16 +++++++++++-----
 vllm/entrypoints/api_server.py        |  7 ++++++-
 vllm/entrypoints/openai/api_server.py |  9 ++++++++-
 3 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 54f3867694e5..d4ce8597f31e 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -155,11 +155,15 @@ class AsyncLLMEngine:
         self.finished_requests: Set[str] = set()
         self.background_loop = None
         if start_engine_loop:
-            self._start_background_loop()
+            self.start_background_loop()
 
-    def _start_background_loop(self) -> None:
+    @property
+    def is_running(self) -> bool:
+        return self.background_loop is not None
+
+    def start_background_loop(self) -> None:
         """Start the background loop."""
-        if self.background_loop is not None:
+        if self.is_running:
             raise RuntimeError("Background loop is already running.")
         self.background_loop = asyncio.get_event_loop().create_task(
             self.run_engine_loop())
@@ -323,7 +327,8 @@ class AsyncLLMEngine:
 
     @classmethod
     def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
@@ -338,5 +343,6 @@ class AsyncLLMEngine:
                      distributed_init_method, placement_group,
                      log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats)
+                     log_stats=not engine_args.disable_log_stats,
+                     start_engine_loop=start_engine_loop)
         return engine
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 58ea2e229125..f430c0fff51e 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -30,6 +30,10 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
+
+    if not engine.is_running:
+        engine.start_background_loop()
+
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -75,7 +79,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args,
+                                             start_engine_loop=False)
 
     uvicorn.run(app,
                 host=args.host,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 385bff679e83..1629660097e6 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -191,6 +191,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
+    if not engine.is_running:
+        engine.start_background_loop()
+
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -363,6 +366,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
+    if not engine.is_running:
+        engine.start_background_loop()
+
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -620,7 +626,8 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args,
+                                             start_engine_loop=False)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.get_max_model_len()
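
A minimal usage sketch (not part of the patch) of the lazy-start flow this change enables: the engine is built with start_engine_loop=False, and the background loop is only created once a request arrives inside a running event loop, mirroring the patched entrypoints. The import paths and model name below are assumptions for illustration; is_running, start_background_loop, and start_engine_loop are the names introduced by the diff.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs        # assumed import path
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams          # assumed import path
from vllm.utils import random_uuid                       # assumed import path

# Defer creating the background task; it will be started on the first request,
# when an event loop is actually running.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"),  # illustrative model name
    start_engine_loop=False)


async def handle(prompt: str) -> None:
    # First request: the loop task does not exist yet, so create it here,
    # on the event loop that is running this coroutine.
    if not engine.is_running:
        engine.start_background_loop()
    results_generator = engine.generate(prompt, SamplingParams(),
                                        random_uuid())
    async for request_output in results_generator:
        print(request_output.outputs[0].text)


asyncio.run(handle("Hello, my name is"))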