[BugFix] Fix clean shutdown issues (#8492)

Nick Hill, 2024-09-16 17:33:46 +01:00, committed by GitHub
parent 837c1968f9
commit acd5511b6d
11 changed files with 213 additions and 134 deletions

View File

@@ -26,6 +26,11 @@ class RequestOutput:
     finished: bool = False


+@dataclass
+class MockModelConfig:
+    use_async_output_proc = True
+
+
 class MockEngine:

     def __init__(self):
@@ -35,6 +40,7 @@ class MockEngine:
         self.request_id = None
         # Ugly, remove dependency when possible
         self.parallel_config = ParallelConfig(1, 1, False)
+        self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False)
+    engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

-    engine = MockAsyncLLMEngine(worker_use_ray=True)
+    engine = MockAsyncLLMEngine()
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None

View File

@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)

 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
             raise RuntimeError(
                 "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
     def from_engine_args(
             cls,
             engine_args: AsyncEngineArgs,
+            engine_config: Optional[EngineConfig] = None,
             start_engine_loop: bool = True,
             usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
             stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()
+
         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
                     has_unfinished_requests = (
-                        self.engine.has_unfinished_requests_for_virtual_engine(
+                        engine.engine.
+                        has_unfinished_requests_for_virtual_engine(
                             virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
-                                self.engine_step(virtual_engine)))
+                                engine.engine_step(virtual_engine)))
                         has_requests_in_progress[virtual_engine] = True
                     else:
                         has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)
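
Note on the pattern above: run_engine_loop becomes a staticmethod that receives only a weakref to the engine, so the long-lived background task never holds a strong reference that would keep the engine (and its executor, workers, and sockets) alive for the life of the process. A minimal, self-contained sketch of the same idea outside vLLM (class and attribute names here are illustrative only, not part of the diff):

import asyncio
import weakref


class ToyEngine:
    """Toy engine whose background loop holds only a weak reference."""

    def __init__(self) -> None:
        self.new_work = asyncio.Event()
        # Hand the task a weakref instead of `self`.
        self._task = asyncio.get_running_loop().create_task(
            ToyEngine._loop(weakref.ref(self)))

    def __del__(self):
        # Wake the loop so it notices the engine is gone and exits.
        self.new_work.set()

    @staticmethod
    async def _loop(engine_ref: "weakref.ReferenceType[ToyEngine]") -> None:
        while True:
            engine = engine_ref()
            if engine is None:
                return  # engine was garbage collected: exit cleanly
            event = engine.new_work
            del engine  # drop the strong ref while idle
            await event.wait()
            event.clear()


async def main() -> None:
    engine = ToyEngine()
    await asyncio.sleep(0)    # let the loop reach its first wait
    del engine                # only the task's weakref remains
    await asyncio.sleep(0.1)  # loop is woken by __del__ and returns


asyncio.run(main())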

View File

@@ -1,8 +1,8 @@
-import functools
 import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device
+from vllm.utils import Counter, Device, weak_bind
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -382,11 +382,16 @@ class LLMEngine:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        self.async_callbacks = [
-            functools.partial(self._process_model_outputs,
-                              ctx=self.scheduler_contexts[v_id])
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
+        if model_config.use_async_output_proc:
+            process_model_outputs = weak_bind(self._process_model_outputs)
+
+            self.async_callbacks = [
+                partial(process_model_outputs,
+                        ctx=self.scheduler_contexts[v_id])
+                for v_id in range(self.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.async_callbacks = []

         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
@@ -869,8 +874,8 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    @staticmethod
     def _process_sequence_group_outputs(
-            self,
             seq_group: SequenceGroup,
             outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:
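
Why the switch from functools.partial over bound methods to weak_bind matters: a bound method stores its instance in __self__, so keeping such callbacks in a list on the engine creates a self-referential cycle that only the cyclic garbage collector can reclaim, which is exactly what delays clean shutdown. A small self-contained illustration of that cycle, using only the standard library and illustrative names:

import gc
import weakref
from functools import partial


class Engine:

    def __init__(self) -> None:
        # A bound method keeps a strong reference to `self`, so storing
        # these partials on the engine forms a cycle:
        # engine -> list -> partial -> bound method -> engine
        self.async_callbacks = [
            partial(self._process_model_outputs, ctx=i) for i in range(2)
        ]

    def _process_model_outputs(self, ctx: int) -> None:
        pass


engine = Engine()
ref = weakref.ref(engine)
del engine
print(ref() is None)   # False: reference counting alone cannot reclaim the cycle
gc.collect()
print(ref() is None)   # True: only the cyclic GC frees it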

View File

@@ -1,21 +1,20 @@
 import asyncio
 import signal
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Optional

 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Request, Response

 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port

 logger = init_logger(__name__)


-async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

     # Set concurrency limits in uvicorn if running in multiprocessing mode
     # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
-    if engine.limit_concurrency is not None:
+    if limit_concurrency is not None:
         logger.info(
             "Launching Uvicorn with --limit_concurrency %s. To avoid this "
             "limit at the expense of performance run with "
-            "--disable-frontend-multiprocessing", engine.limit_concurrency)
-        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+            "--disable-frontend-multiprocessing", limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = limit_concurrency

     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
-    _add_shutdown_handlers(app, server, engine)
+    _add_shutdown_handlers(app, server)

     loop = asyncio.get_running_loop()
@@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
     return server.shutdown()


-def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
-                           engine: AsyncEngineClient) -> None:
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""

     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(_, __):
+    async def runtime_error_handler(request: Request, __):
         """On generic runtime error, check to see if the engine has died.
         It probably has, in which case the server will no longer be able to
         handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        engine = request.app.state.engine_client
         if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                 and not engine.is_running):
             logger.fatal("AsyncLLMEngine has failed, terminating server "

View File

@@ -4,16 +4,20 @@ import inspect
 import multiprocessing
 import os
 import re
+import signal
 import tempfile
 from argparse import Namespace
 from contextlib import asynccontextmanager
+from functools import partial
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set

+import uvloop
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never
@@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
 prometheus_multiproc_dir: tempfile.TemporaryDirectory

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-
-    async def _force_log():
-        while True:
-            await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
-
-    if not engine_args.disable_log_stats:
-        task = asyncio.create_task(_force_log())
-        _running_tasks.add(task)
-        task.add_done_callback(_running_tasks.remove)
-
-    yield
+    try:
+        if app.state.log_stats:
+            async_engine_client = app.state.engine_client
+
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(10)
+                    await async_engine_client.do_log_stats()
+
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state
@asynccontextmanager @asynccontextmanager
@ -103,16 +111,10 @@ async def build_async_engine_client(
# Context manager to handle async_engine_client lifecycle # Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine: engine_args, args.disable_frontend_multiprocessing) as engine:
async_engine_client = engine # type: ignore[assignment]
yield engine yield engine
@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
     if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                            engine_args.quantization, engine_args.revision)
             or disable_frontend_multiprocessing):
-        engine_client = AsyncLLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        try:
-            yield engine_client
-        finally:
-            engine_client.shutdown_background_loop()
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
         return

     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)


+def chat(request: Request) -> OpenAIServingChat:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> OpenAIServingCompletion:
+    return request.app.state.openai_serving_completion
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def embedding(request: Request) -> OpenAIServingEmbedding:
+    return request.app.state.openai_serving_embedding
+
+
+def engine_client(request: Request) -> AsyncEngineClient:
+    return request.app.state.engine_client
+
+
 @router.get("/health")
-async def health() -> Response:
+async def health(raw_request: Request) -> Response:
     """Health check."""
-    await async_engine_client.check_health()
+    await engine_client(raw_request).check_health()
     return Response(status_code=200)


 @router.post("/tokenize")
-async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_tokenization.create_tokenize(request)
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):


 @router.post("/detokenize")
-async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_tokenization.create_detokenize(request)
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):


 @router.get("/v1/models")
-async def show_available_models():
-    models = await openai_serving_completion.show_available_models()
+async def show_available_models(raw_request: Request):
+    models = await completion(raw_request).show_available_models()
     return JSONResponse(content=models.model_dump())
@@ -288,7 +320,7 @@ async def show_version():
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):

-    generator = await openai_serving_chat.create_chat_completion(
+    generator = await chat(raw_request).create_chat_completion(
         request, raw_request)

     if isinstance(generator, ErrorResponse):
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await openai_serving_completion.create_completion(
+    generator = await completion(raw_request).create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await openai_serving_embedding.create_embedding(
+    generator = await embedding(raw_request).create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
                    "used for local development!")

     @router.post("/start_profile")
-    async def start_profile():
+    async def start_profile(raw_request: Request):
         logger.info("Starting profiler...")
-        await async_engine_client.start_profile()
+        await engine_client(raw_request).start_profile()
         logger.info("Profiler started.")
         return Response(status_code=200)

     @router.post("/stop_profile")
-    async def stop_profile():
+    async def stop_profile(raw_request: Request):
         logger.info("Stopping profiler...")
-        await async_engine_client.stop_profile()
+        await engine_client(raw_request).stop_profile()
         logger.info("Profiler stopped.")
         return Response(status_code=200)
@@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
                    "This should ONLY be used for local development!")

     @router.post("/v1/load_lora_adapter")
-    async def load_lora_adapter(request: LoadLoraAdapterRequest):
-        response = await openai_serving_chat.load_lora_adapter(request)
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        response = await chat(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.load_lora_adapter(request)
+        response = await completion(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         return Response(status_code=200, content=response)

     @router.post("/v1/unload_lora_adapter")
-    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
-        response = await openai_serving_chat.unload_lora_adapter(request)
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        response = await chat(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.unload_lora_adapter(request)
+        response = await completion(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_, exc):
-        err = openai_serving_chat.create_error_response(message=str(exc))
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)
@@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
     return app


-async def init_app(
+def init_app_state(
     async_engine_client: AsyncEngineClient,
+    model_config: ModelConfig,
+    state: State,
     args: Namespace,
-) -> FastAPI:
-    app = build_app(args)
-
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    model_config = await async_engine_client.get_model_config()
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

-    global openai_serving_chat
-    global openai_serving_completion
-    global openai_serving_embedding
-    global openai_serving_tokenization
+    state.engine_client = async_engine_client
+    state.log_stats = not args.disable_log_stats

-    openai_serving_chat = OpenAIServingChat(
+    state.openai_serving_chat = OpenAIServingChat(
         async_engine_client,
         model_config,
         served_model_names,
@@ -465,7 +496,7 @@ async def init_app(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser)
-    openai_serving_completion = OpenAIServingCompletion(
+    state.openai_serving_completion = OpenAIServingCompletion(
         async_engine_client,
         model_config,
         served_model_names,
@@ -474,13 +505,13 @@ async def init_app(
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
-    openai_serving_embedding = OpenAIServingEmbedding(
+    state.openai_serving_embedding = OpenAIServingEmbedding(
         async_engine_client,
         model_config,
         served_model_names,
         request_logger=request_logger,
     )
-    openai_serving_tokenization = OpenAIServingTokenization(
+    state.openai_serving_tokenization = OpenAIServingTokenization(
         async_engine_client,
         model_config,
         served_model_names,
@@ -488,25 +519,31 @@ async def init_app(
         request_logger=request_logger,
         chat_template=args.chat_template,
     )
-    app.root_path = args.root_path
-
-    return app


 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     async with build_async_engine_client(args) as async_engine_client:
         # If None, creation of the client failed and we exit.
         if async_engine_client is None:
             return

-        app = await init_app(async_engine_client, args)
+        app = build_app(args)
+
+        model_config = await async_engine_client.get_model_config()
+        init_app_state(async_engine_client, model_config, app.state, args)

         shutdown_task = await serve_http(
             app,
-            engine=async_engine_client,
+            limit_concurrency=async_engine_client.limit_concurrency,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
@@ -530,4 +567,4 @@ if __name__ == "__main__":
     parser = make_arg_parser(parser)
     args = parser.parse_args()

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))
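
The accessor helpers introduced above (chat(), completion(), engine_client(), ...) replace module-level globals with per-application state: everything is attached to app.state once at startup, and handlers reach it through request.app, so dropping the app also drops the engine reference and nothing outlives shutdown. A stripped-down sketch of the same FastAPI pattern, using a stand-in client class rather than the real vLLM types:

from fastapi import FastAPI, Request, Response


class StubEngineClient:
    """Stand-in for the real engine client; illustrative only."""

    async def check_health(self) -> None:
        pass


app = FastAPI()
# Attach shared objects to application state instead of module globals.
app.state.engine_client = StubEngineClient()


def engine_client(request: Request) -> StubEngineClient:
    # Handlers resolve shared state through the request's app reference.
    return request.app.state.engine_client


@app.get("/health")
async def health(raw_request: Request) -> Response:
    await engine_client(raw_request).check_health()
    return Response(status_code=200)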

View File

@@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
         """Cleanup all resources."""
         self.socket.close()
         self.context.destroy()
-        self.engine.shutdown_background_loop()
         # Clear the engine reference so that it can be GC'ed.
         del self.engine
@@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):

 def run_rpc_server(async_engine_args: AsyncEngineArgs,
                    usage_context: UsageContext, rpc_path: str):
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
     uvloop.run(run_server(server))
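
Both servers now install the same small SIGTERM handler before any blocking setup work. Re-raising the signal as KeyboardInterrupt means a SIGTERM received while the engine is still initializing unwinds through the normal try/finally and context-manager cleanup paths instead of killing the process abruptly. A generic sketch of the pattern, where the loop below merely stands in for blocking server startup:

import signal
import time


def _sigterm_handler(*_) -> None:
    # Convert SIGTERM into KeyboardInterrupt so ordinary cleanup still runs.
    raise KeyboardInterrupt("terminated")


signal.signal(signal.SIGTERM, _sigterm_handler)

try:
    while True:        # stand-in for blocking engine/server initialization
        time.sleep(1)
except KeyboardInterrupt:
    print("interrupted: shutting down cleanly")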

View File

@@ -1,8 +1,5 @@
 import asyncio
 import os
-import signal
-import threading
-import weakref
 from functools import partial
 from typing import Any, List, Optional
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):

-        # Set up signal handlers to shutdown the executor cleanly
-        # sometimes gc does not work well
-
-        # Use weakref to avoid holding a reference to self
-        ref = weakref.ref(self)
-
-        def shutdown(signum, frame):
-            if executor := ref():
-                executor.shutdown()
-
-        if threading.current_thread() is threading.main_thread():
-            signal.signal(signal.SIGINT, shutdown)
-            signal.signal(signal.SIGTERM, shutdown)
-
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")

View File

@@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
                     logger.error("Worker %s pid %s died, exit code: %s",
                                  process.name, process.pid, process.exitcode)
         # Cleanup any remaining workers
-        logger.info("Killing local vLLM worker processes")
+        if logger:
+            logger.info("Killing local vLLM worker processes")
         for worker in self.workers:
             worker.kill_worker()
         # Must be done after worker task queues are all closed
@@ -221,6 +222,8 @@ def _run_worker_process(
         try:
             executor = getattr(worker, method)
             output = executor(*args, **kwargs)
+        except KeyboardInterrupt:
+            break
         except BaseException as e:
             tb = traceback.format_exc()
             logger.error(

View File

@@ -26,6 +26,8 @@ logger = init_logger(__name__)

 class RayTPUExecutor(TPUExecutor):

+    uses_ray: bool = True
+
     def __init__(self, *args, **kwargs):
         # This is non-None when the execute model loop is running
         # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.

View File

@@ -1,11 +1,11 @@
 # The CLI entrypoint to vLLM.
 import argparse
-import asyncio
 import os
 import signal
 import sys
 from typing import List, Optional

+import uvloop
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam
@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
     # EngineArgs expects the model name to be passed as --model.
     args.model = args.model_tag

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))


 def interactive_cli(args: argparse.Namespace) -> None:

View File

@@ -12,6 +12,7 @@ import tempfile
 import threading
 import uuid
 import warnings
+import weakref
 from asyncio import FIRST_COMPLETED, ensure_future
 from functools import lru_cache, partial, wraps
 from platform import uname
@@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


+def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
+    """Make an instance method that weakly references
+    its associated instance and no-ops once that
+    instance is collected."""
+    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
+    unbound = bound_method.__func__  # type: ignore[attr-defined]
+
+    def weak_bound(*args, **kwargs) -> None:
+        if inst := ref():
+            unbound(inst, *args, **kwargs)
+
+    return weak_bound
+
+
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f: Callable[P, None]) -> Callable[P, None]:
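
Usage-wise, weak_bind behaves like the original bound method until its instance goes away, after which calls silently become no-ops instead of keeping the instance alive. A quick demonstration, assuming the helper above is importable as vllm.utils.weak_bind (the Reporter class is just an example, not part of vLLM):

from vllm.utils import weak_bind


class Reporter:

    def __init__(self, name: str) -> None:
        self.name = name

    def report(self, msg: str) -> None:
        print(f"{self.name}: {msg}")


r = Reporter("engine")
cb = weak_bind(r.report)   # holds only a weak reference to `r`

cb("still alive")          # prints "engine: still alive"
del r                      # instance is reclaimed; cb did not keep it alive
cb("after collection")     # silently does nothing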