Initialize AsyncLLMEngine bg loop correctly (#943)

This commit is contained in:
Antoni Baum 2023-09-04 17:41:22 -07:00 committed by GitHub
parent 002800f081
commit 1696725879
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 7 deletions

View File

@ -155,11 +155,15 @@ class AsyncLLMEngine:
self.finished_requests: Set[str] = set() self.finished_requests: Set[str] = set()
self.background_loop = None self.background_loop = None
if start_engine_loop: if start_engine_loop:
self._start_background_loop() self.start_background_loop()
def _start_background_loop(self) -> None: @property
def is_running(self) -> bool:
return self.background_loop is not None
def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
if self.background_loop is not None: if self.is_running:
raise RuntimeError("Background loop is already running.") raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task( self.background_loop = asyncio.get_event_loop().create_task(
self.run_engine_loop()) self.run_engine_loop())
@ -323,7 +327,8 @@ class AsyncLLMEngine:
@classmethod @classmethod
def from_engine_args(cls, def from_engine_args(cls,
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine": engine_args: AsyncEngineArgs,
start_engine_loop: bool = False) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
@ -338,5 +343,6 @@ class AsyncLLMEngine:
distributed_init_method, distributed_init_method,
placement_group, placement_group,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop)
return engine return engine

View File

@ -30,6 +30,10 @@ async def generate(request: Request) -> Response:
stream = request_dict.pop("stream", False) stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
if not engine.is_running:
engine.start_background_loop()
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case # Streaming case
@ -75,7 +79,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args,
start_engine_loop=False)
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,

View File

@ -191,6 +191,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
""" """
logger.info(f"Received chat completion request: {request}") logger.info(f"Received chat completion request: {request}")
if not engine.is_running:
engine.start_background_loop()
error_check_ret = await check_model(request) error_check_ret = await check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
@ -363,6 +366,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
""" """
logger.info(f"Received completion request: {request}") logger.info(f"Received completion request: {request}")
if not engine.is_running:
engine.start_background_loop()
error_check_ret = await check_model(request) error_check_ret = await check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
@ -620,7 +626,8 @@ if __name__ == "__main__":
served_model = args.model served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args,
start_engine_loop=False)
engine_model_config = asyncio.run(engine.get_model_config()) engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len() max_model_len = engine_model_config.get_max_model_len()