Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-10 02:05:01 +08:00
Initialize AsyncLLMEngine bg loop correctly (#943)
parent 002800f081
commit 1696725879
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -155,11 +155,15 @@ class AsyncLLMEngine:
         self.finished_requests: Set[str] = set()
         self.background_loop = None
         if start_engine_loop:
-            self._start_background_loop()
+            self.start_background_loop()
 
-    def _start_background_loop(self) -> None:
+    @property
+    def is_running(self) -> bool:
+        return self.background_loop is not None
+
+    def start_background_loop(self) -> None:
         """Start the background loop."""
-        if self.background_loop is not None:
+        if self.is_running:
             raise RuntimeError("Background loop is already running.")
         self.background_loop = asyncio.get_event_loop().create_task(
             self.run_engine_loop())
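The hunk above turns the old private _start_background_loop into a public start_background_loop guarded by an is_running property, so callers can defer starting the loop until an event loop is actually running. A minimal standalone sketch of the pattern; the Engine class, run_engine_loop body, and handle_request coroutine here are illustrative stand-ins, not vLLM code:

import asyncio


class Engine:
    """Minimal stand-in for AsyncLLMEngine's lazy loop start."""

    def __init__(self) -> None:
        self.background_loop = None

    @property
    def is_running(self) -> bool:
        return self.background_loop is not None

    def start_background_loop(self) -> None:
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Binds the task to whichever loop is current *now*; called from
        # inside a request handler, that is the server's running loop.
        self.background_loop = asyncio.get_event_loop().create_task(
            self.run_engine_loop())

    async def run_engine_loop(self) -> None:
        while True:
            await asyncio.sleep(1)  # placeholder for real engine steps


engine = Engine()  # constructed at import time, before any loop exists


async def handle_request() -> None:
    if not engine.is_running:  # same lazy-start check as in the diff
        engine.start_background_loop()


asyncio.run(handle_request())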
@@ -323,7 +327,8 @@ class AsyncLLMEngine:
 
     @classmethod
     def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
@@ -338,5 +343,6 @@ class AsyncLLMEngine:
                      distributed_init_method,
                      placement_group,
                      log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats)
+                     log_stats=not engine_args.disable_log_stats,
+                     start_engine_loop=start_engine_loop)
         return engine
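A short usage sketch of the new factory signature (assuming the usual vllm.engine import paths; the model name is just a placeholder):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Build the engine without starting its background loop; a server can
# call engine.start_background_loop() later, once its own event loop
# is running.
engine_args = AsyncEngineArgs(model="facebook/opt-125m")
engine = AsyncLLMEngine.from_engine_args(engine_args, start_engine_loop=False)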
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -30,6 +30,10 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
+
+    if not engine.is_running:
+        engine.start_background_loop()
+
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -75,7 +79,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args,
+                                             start_engine_loop=False)
 
     uvicorn.run(app,
                 host=args.host,
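The reason for start_engine_loop=False here is that uvicorn.run() creates its own event loop: a task created on a different, never-started loop simply never executes. A rough sketch of that failure mode (all names below are illustrative):

import asyncio

flag = {"started": False}


async def background() -> None:
    flag["started"] = True


# A task created on one loop at import time...
import_time_loop = asyncio.new_event_loop()
task = import_time_loop.create_task(background())

# ...never runs under the different loop the server drives.
server_loop = asyncio.new_event_loop()
server_loop.run_until_complete(asyncio.sleep(0.1))
print("background started?", flag["started"])  # False

# Drain the first loop just to clean up; a real server never does this.
import_time_loop.run_until_complete(task)
import_time_loop.close()
server_loop.close()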
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -191,6 +191,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
+    if not engine.is_running:
+        engine.start_background_loop()
+
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -363,6 +366,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
+    if not engine.is_running:
+        engine.start_background_loop()
+
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
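With the lazy start in place, the background loop comes up on the first request to either handler. For example, using the pre-1.0 openai Python client; the host, port, and model name are assumptions, not taken from the diff:

import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "EMPTY"  # placeholder value

# The first request triggers engine.start_background_loop() server-side.
completion = openai.Completion.create(model="facebook/opt-125m",
                                      prompt="San Francisco is a")
print(completion.choices[0].text)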
@@ -620,7 +626,8 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args,
+                                             start_engine_loop=False)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()
 
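Note the ordering this last hunk enables: engine.get_model_config() is driven with asyncio.run() before uvicorn starts, which is only safe because start_engine_loop=False keeps the constructor from binding a background task to that short-lived loop. A self-contained sketch of the idea (all names illustrative):

import asyncio


class Engine:
    def __init__(self, start_engine_loop: bool = False) -> None:
        self.background_loop = None
        if start_engine_loop:
            # Would bind a task to whatever loop is current here, which is
            # the wrong one if the server creates its own loop later.
            self.background_loop = asyncio.get_event_loop().create_task(
                self.run_engine_loop())

    async def run_engine_loop(self) -> None:
        while True:
            await asyncio.sleep(1)

    async def get_model_config(self) -> str:
        return "model config"


engine = Engine(start_engine_loop=False)
# Safe: the temporary loop created by asyncio.run() owns no engine task.
print(asyncio.run(engine.get_model_config()))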