Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
commit 080438477f
parent 4b5bcf8906
@@ -40,8 +40,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,
@@ -230,6 +230,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
 
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 start_engine_loop: bool = False,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
 
         self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
-        if start_engine_loop:
-            self.start_background_loop()
+        self.start_engine_loop = start_engine_loop
 
     @property
     def is_running(self) -> bool:
@@ -330,11 +331,14 @@ class AsyncLLMEngine:
                         f"prompt token ids: {prompt_token_ids}.")
 
         if not self.is_running:
-            raise AsyncEngineDeadError(
-                "Background loop is not running. If it was running, "
-                "inspect the output to find the stacktrace of the "
-                "error that caused the background loop to stop "
-                "(AsyncEngineDeadError).")
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
 
         stream = self.request_tracker.add_request(
             request_id,
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
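With start_engine_loop now defaulting to True, callers no longer need to start the background task themselves: the first generate() call starts it if it is not already running, and AsyncEngineDeadError is only raised when automatic startup is disabled. A minimal usage sketch of that behavior; the model name and sampling values below are illustrative and not part of this commit:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main():
    # start_engine_loop defaults to True, so no explicit
    # start_background_loop() call is needed here.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # illustrative model
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # The first generate() call lazily starts the background loop.
    results_generator = engine.generate("Hello, my name is",
                                        SamplingParams(temperature=0.8),
                                        random_uuid())
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    print(final_output.outputs[0].text)


asyncio.run(main())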
@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -80,8 +77,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     uvicorn.run(app,
                 host=args.host,
@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -627,8 +621,7 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()
 
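On the entrypoint side, the request handlers no longer call engine.start_background_loop() themselves, and the servers construct the engine without start_engine_loop=False. A rough sketch of the resulting wiring for the plain API server, assuming the app object already defined in vllm.entrypoints.api_server; the model name, host, and port are illustrative:

import uvicorn

import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Reuse the FastAPI app defined by the API server module.
app = vllm.entrypoints.api_server.app

if __name__ == "__main__":
    # CLI parsing elided; engine_args would normally come from
    # AsyncEngineArgs.from_cli_args(args).
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # illustrative model
    # No start_engine_loop=False: the loop starts on the first request's
    # generate() call.
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    vllm.entrypoints.api_server.engine = engine
    uvicorn.run(app, host="0.0.0.0", port=8000)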