Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 00:44:57 +08:00)

[BugFix] Fix clean shutdown issues (#8492)

Commit: acd5511b6d
Parent: 837c1968f9
@@ -26,6 +26,11 @@ class RequestOutput:
     finished: bool = False


+@dataclass
+class MockModelConfig:
+    use_async_output_proc = True
+
+
 class MockEngine:

     def __init__(self):
@@ -35,6 +40,7 @@ class MockEngine:
         self.request_id = None
         # Ugly, remove dependency when possible
         self.parallel_config = ParallelConfig(1, 1, False)
+        self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False)
+    engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

-    engine = MockAsyncLLMEngine(worker_use_ray=True)
+    engine = MockAsyncLLMEngine()
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
                     has_unfinished_requests = (
-                        self.engine.has_unfinished_requests_for_virtual_engine(
+                        engine.engine.
+                        has_unfinished_requests_for_virtual_engine(
                             virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
-                                self.engine_step(virtual_engine)))
+                                engine.engine_step(virtual_engine)))
                         has_requests_in_progress[virtual_engine] = True
                     else:
                         has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)
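
Note on the pattern above: run_engine_loop is now a static method that receives weakref.ref(self), so the long-running background task no longer holds a strong reference to the AsyncLLMEngine; once the last external reference is dropped, __del__ can fire, wake the loop through the request tracker event, and let it exit. The following is a minimal, self-contained sketch of that weak-reference loop technique; the Worker class and its names are hypothetical illustrations, not vLLM APIs.

import asyncio
import weakref
from typing import Optional


class Worker:
    """Toy object owning a background loop that must not keep it alive."""

    def __init__(self) -> None:
        self.new_work = asyncio.Event()
        # The task is given only a weak reference to its owner.
        self._task = asyncio.get_running_loop().create_task(
            Worker._run_loop(weakref.ref(self)))

    def __del__(self) -> None:
        # Wake the loop so it can observe the dead weakref and exit cleanly.
        self.new_work.set()

    @staticmethod
    async def _run_loop(worker_ref: "weakref.ReferenceType") -> None:
        while True:
            worker: Optional["Worker"] = worker_ref()
            if worker is None:
                return                  # owner collected -> stop the loop
            event = worker.new_work
            del worker                  # drop the strong ref while waiting
            await event.wait()
            if worker_ref() is None:
                return                  # woken by __del__ -> exit cleanly
            event.clear()
            print("processed one batch of work")


async def main() -> None:
    w = Worker()
    w.new_work.set()
    await asyncio.sleep(0)              # let the loop process once
    del w                               # collecting w lets the loop finish
    await asyncio.sleep(0)


asyncio.run(main())
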
@@ -1,8 +1,8 @@
-import functools
 import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device
+from vllm.utils import Counter, Device, weak_bind
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -382,11 +382,16 @@ class LLMEngine:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        self.async_callbacks = [
-            functools.partial(self._process_model_outputs,
-                              ctx=self.scheduler_contexts[v_id])
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
+        if model_config.use_async_output_proc:
+            process_model_outputs = weak_bind(self._process_model_outputs)
+
+            self.async_callbacks = [
+                partial(process_model_outputs,
+                        ctx=self.scheduler_contexts[v_id])
+                for v_id in range(self.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.async_callbacks = []

         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
@@ -869,8 +874,8 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    @staticmethod
     def _process_sequence_group_outputs(
-        self,
         seq_group: SequenceGroup,
         outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:
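
In LLMEngine the per-virtual-engine callbacks are now built from weak_bind plus functools.partial, and only when async output processing is enabled, so each callback carries its scheduler context without pinning the engine itself. A reduced sketch of composing the two (the Engine class below is hypothetical; weak_bind is the helper this commit adds to vllm.utils):

from functools import partial

from vllm.utils import weak_bind  # added in this commit


class Engine:

    def __init__(self, num_virtual_engines: int,
                 use_async_output_proc: bool) -> None:
        self.contexts = [f"ctx-{i}" for i in range(num_virtual_engines)]
        if use_async_output_proc:
            # weak_bind keeps only a weak reference to `self`.
            process_outputs = weak_bind(self._process_model_outputs)
            # One callback per virtual engine, pre-bound to its context.
            self.async_callbacks = [
                partial(process_outputs, ctx=ctx) for ctx in self.contexts
            ]
        else:
            self.async_callbacks = []

    def _process_model_outputs(self, ctx: str) -> None:
        print(f"processing outputs for {ctx}")


engine = Engine(num_virtual_engines=2, use_async_output_proc=True)
engine.async_callbacks[1]()   # -> processing outputs for ctx-1
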
@@ -1,21 +1,20 @@
 import asyncio
 import signal
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Optional

 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Request, Response

 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port

 logger = init_logger(__name__)


-async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

     # Set concurrency limits in uvicorn if running in multiprocessing mode
     # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
-    if engine.limit_concurrency is not None:
+    if limit_concurrency is not None:
         logger.info(
             "Launching Uvicorn with --limit_concurrency %s. To avoid this "
             "limit at the expense of performance run with "
-            "--disable-frontend-multiprocessing", engine.limit_concurrency)
-        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+            "--disable-frontend-multiprocessing", limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = limit_concurrency

     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
-    _add_shutdown_handlers(app, server, engine)
+    _add_shutdown_handlers(app, server)

     loop = asyncio.get_running_loop()

@@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
     return server.shutdown()


-def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
-                           engine: AsyncEngineClient) -> None:
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""

     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(_, __):
+    async def runtime_error_handler(request: Request, __):
         """On generic runtime error, check to see if the engine has died.
         It probably has, in which case the server will no longer be able to
         handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        engine = request.app.state.engine_client
         if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                 and not engine.is_running):
             logger.fatal("AsyncLLMEngine has failed, terminating server "
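
With these changes serve_http no longer receives the engine at all: it only needs the concurrency limit, and the RuntimeError handler recovers the engine client from request.app.state when deciding whether the server should die. A condensed sketch of that handler pattern follows; the errored/is_running attributes mirror the fields checked above, and sending SIGTERM to the current process is only one illustrative way to request the graceful shutdown the docstring refers to, not necessarily the exact mechanism used here.

import os
import signal
from http import HTTPStatus

from fastapi import FastAPI, Request, Response


def add_shutdown_handlers(app: FastAPI) -> None:
    """Crash the whole server once the engine behind it has died."""

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(request: Request, __) -> Response:
        # Look the engine up on app state instead of a module-level global.
        engine = request.app.state.engine_client
        if engine.errored and not engine.is_running:
            # Illustrative: ask this process to shut down gracefully.
            os.kill(os.getpid(), signal.SIGTERM)
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
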
@@ -4,16 +4,20 @@ import inspect
 import multiprocessing
 import os
 import re
+import signal
 import tempfile
 from argparse import Namespace
 from contextlib import asynccontextmanager
+from functools import partial
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set

+import uvloop
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never

@@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
 prometheus_multiproc_dir: tempfile.TemporaryDirectory

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,

 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    try:
+        if app.state.log_stats:
+            async_engine_client = app.state.engine_client

-    async def _force_log():
-        while True:
-            await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(10)
+                    await async_engine_client.do_log_stats()

-    if not engine_args.disable_log_stats:
-        task = asyncio.create_task(_force_log())
-        _running_tasks.add(task)
-        task.add_done_callback(_running_tasks.remove)
-
-    yield
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state


 @asynccontextmanager
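
FastAPI runs the code before `yield` in a lifespan context at startup and the code after it at shutdown; the nested try/finally added above guarantees the stats task is cancelled and `app.state` (which now holds the engine reference) is released even when shutdown is triggered by an error. A minimal sketch of the same shape, with a hypothetical heartbeat task standing in for the stats logger:

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    async def heartbeat() -> None:
        while True:
            await asyncio.sleep(10)
            print("still serving")

    task = asyncio.create_task(heartbeat())
    try:
        yield               # the server handles requests while suspended here
    finally:
        task.cancel()       # always stop the background task
        del app.state       # drop anything the app was keeping alive


app = FastAPI(lifespan=lifespan)
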
@@ -103,16 +111,10 @@ async def build_async_engine_client(

     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
-    global engine_args
     engine_args = AsyncEngineArgs.from_cli_args(args)

-    # Backend itself still global for the silly lil' health handler
-    global async_engine_client
-
     async with build_async_engine_client_from_engine_args(
             engine_args, args.disable_frontend_multiprocessing) as engine:
-
-        async_engine_client = engine  # type: ignore[assignment]
         yield engine


@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
     if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                            engine_args.quantization, engine_args.revision)
             or disable_frontend_multiprocessing):
-        engine_client = AsyncLLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        try:
-            yield engine_client
-        finally:
-            engine_client.shutdown_background_loop()
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
         return

     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
         app.routes.append(metrics_route)


+def chat(request: Request) -> OpenAIServingChat:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> OpenAIServingCompletion:
+    return request.app.state.openai_serving_completion
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def embedding(request: Request) -> OpenAIServingEmbedding:
+    return request.app.state.openai_serving_embedding
+
+
+def engine_client(request: Request) -> AsyncEngineClient:
+    return request.app.state.engine_client
+
+
 @router.get("/health")
-async def health() -> Response:
+async def health(raw_request: Request) -> Response:
     """Health check."""
-    await async_engine_client.check_health()
+    await engine_client(raw_request).check_health()
     return Response(status_code=200)


 @router.post("/tokenize")
-async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_tokenization.create_tokenize(request)
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
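
The chat/completion/tokenization/embedding/engine_client helpers above replace module-level globals: every route handler now takes the raw Request and resolves its collaborators from request.app.state, so nothing outside the application object keeps a strong reference to the engine. Roughly, under hypothetical names:

from fastapi import FastAPI, Request, Response

app = FastAPI()
# Populated once at startup (here with a trivial stand-in service).
app.state.greeter = lambda name: f"hello, {name}"


def greeter(request: Request):
    # Resolve the dependency from whichever app served this request.
    return request.app.state.greeter


@app.get("/greet/{name}")
async def greet(name: str, raw_request: Request) -> Response:
    return Response(content=greeter(raw_request)(name))
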
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):


 @router.post("/detokenize")
-async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_tokenization.create_detokenize(request)
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):


 @router.get("/v1/models")
-async def show_available_models():
-    models = await openai_serving_completion.show_available_models()
+async def show_available_models(raw_request: Request):
+    models = await completion(raw_request).show_available_models()
     return JSONResponse(content=models.model_dump())


@@ -288,7 +320,7 @@ async def show_version():
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):

-    generator = await openai_serving_chat.create_chat_completion(
+    generator = await chat(raw_request).create_chat_completion(
         request, raw_request)

     if isinstance(generator, ErrorResponse):
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await openai_serving_completion.create_completion(
+    generator = await completion(raw_request).create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await openai_serving_embedding.create_embedding(
+    generator = await embedding(raw_request).create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
         "used for local development!")

     @router.post("/start_profile")
-    async def start_profile():
+    async def start_profile(raw_request: Request):
         logger.info("Starting profiler...")
-        await async_engine_client.start_profile()
+        await engine_client(raw_request).start_profile()
         logger.info("Profiler started.")
         return Response(status_code=200)

     @router.post("/stop_profile")
-    async def stop_profile():
+    async def stop_profile(raw_request: Request):
         logger.info("Stopping profiler...")
-        await async_engine_client.stop_profile()
+        await engine_client(raw_request).stop_profile()
         logger.info("Profiler stopped.")
         return Response(status_code=200)

@@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         "This should ONLY be used for local development!")

     @router.post("/v1/load_lora_adapter")
-    async def load_lora_adapter(request: LoadLoraAdapterRequest):
-        response = await openai_serving_chat.load_lora_adapter(request)
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        response = await chat(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.load_lora_adapter(request)
+        response = await completion(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         return Response(status_code=200, content=response)

     @router.post("/v1/unload_lora_adapter")
-    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
-        response = await openai_serving_chat.unload_lora_adapter(request)
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        response = await chat(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.unload_lora_adapter(request)
+        response = await completion(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_, exc):
-        err = openai_serving_chat.create_error_response(message=str(exc))
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)

@@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
     return app


-async def init_app(
+def init_app_state(
     async_engine_client: AsyncEngineClient,
+    model_config: ModelConfig,
+    state: State,
     args: Namespace,
-) -> FastAPI:
-    app = build_app(args)
-
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    model_config = await async_engine_client.get_model_config()
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

-    global openai_serving_chat
-    global openai_serving_completion
-    global openai_serving_embedding
-    global openai_serving_tokenization
+    state.engine_client = async_engine_client
+    state.log_stats = not args.disable_log_stats

-    openai_serving_chat = OpenAIServingChat(
+    state.openai_serving_chat = OpenAIServingChat(
         async_engine_client,
         model_config,
         served_model_names,
@@ -465,7 +496,7 @@ async def init_app(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser)
-    openai_serving_completion = OpenAIServingCompletion(
+    state.openai_serving_completion = OpenAIServingCompletion(
         async_engine_client,
         model_config,
         served_model_names,
@@ -474,13 +505,13 @@ async def init_app(
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
-    openai_serving_embedding = OpenAIServingEmbedding(
+    state.openai_serving_embedding = OpenAIServingEmbedding(
         async_engine_client,
         model_config,
         served_model_names,
         request_logger=request_logger,
     )
-    openai_serving_tokenization = OpenAIServingTokenization(
+    state.openai_serving_tokenization = OpenAIServingTokenization(
         async_engine_client,
         model_config,
         served_model_names,
@@ -488,25 +519,31 @@ async def init_app(
         request_logger=request_logger,
         chat_template=args.chat_template,
     )
-    app.root_path = args.root_path
-
-    return app


 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     async with build_async_engine_client(args) as async_engine_client:
         # If None, creation of the client failed and we exit.
         if async_engine_client is None:
             return

-        app = await init_app(async_engine_client, args)
+        app = build_app(args)
+
+        model_config = await async_engine_client.get_model_config()
+        init_app_state(async_engine_client, model_config, app.state, args)

         shutdown_task = await serve_http(
             app,
-            engine=async_engine_client,
+            limit_concurrency=async_engine_client.limit_concurrency,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
@@ -530,4 +567,4 @@ if __name__ == "__main__":
     parser = make_arg_parser(parser)
     args = parser.parse_args()

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))
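
Two smaller changes close out this file: run_server installs a SIGTERM handler that raises KeyboardInterrupt, so a terminate signal arriving while the engine is still starting unwinds the async context managers instead of leaving worker processes behind, and the entrypoint switches from asyncio.run to uvloop.run. A rough, self-contained sketch of that start-up pattern (the build_resources manager is a stand-in for engine construction, not vLLM code):

import asyncio
import signal
from contextlib import asynccontextmanager

import uvloop


@asynccontextmanager
async def build_resources():
    print("starting engine (slow)...")
    try:
        yield "resources"
    finally:
        # Runs when the task is cancelled after SIGTERM/Ctrl-C as well.
        print("resources cleaned up")


async def run_server() -> None:
    def signal_handler(*_) -> None:
        # Turn SIGTERM into the same exit path as Ctrl-C while initializing.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_resources():
        await asyncio.sleep(3600)   # stand-in for serving traffic


if __name__ == "__main__":
    uvloop.run(run_server())        # uvloop's equivalent of asyncio.run
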
@@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
         """Cleanup all resources."""
         self.socket.close()
         self.context.destroy()
-        self.engine.shutdown_background_loop()
         # Clear the engine reference so that it can be GC'ed.
         del self.engine

@@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):

 def run_rpc_server(async_engine_args: AsyncEngineArgs,
                    usage_context: UsageContext, rpc_path: str):
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
     uvloop.run(run_server(server))
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
import threading
|
|
||||||
import weakref
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
|||||||
# Set up signal handlers to shutdown the executor cleanly
|
# Set up signal handlers to shutdown the executor cleanly
|
||||||
# sometimes gc does not work well
|
# sometimes gc does not work well
|
||||||
|
|
||||||
# Use weakref to avoid holding a reference to self
|
|
||||||
ref = weakref.ref(self)
|
|
||||||
|
|
||||||
def shutdown(signum, frame):
|
|
||||||
if executor := ref():
|
|
||||||
executor.shutdown()
|
|
||||||
|
|
||||||
if threading.current_thread() is threading.main_thread():
|
|
||||||
signal.signal(signal.SIGINT, shutdown)
|
|
||||||
signal.signal(signal.SIGTERM, shutdown)
|
|
||||||
|
|
||||||
self.driver_worker = self._create_worker(
|
self.driver_worker = self._create_worker(
|
||||||
distributed_init_method=distributed_init_method)
|
distributed_init_method=distributed_init_method)
|
||||||
self._run_workers("init_device")
|
self._run_workers("init_device")
|
||||||
@@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
                 logger.error("Worker %s pid %s died, exit code: %s",
                              process.name, process.pid, process.exitcode)
         # Cleanup any remaining workers
-        logger.info("Killing local vLLM worker processes")
+        if logger:
+            logger.info("Killing local vLLM worker processes")
         for worker in self.workers:
             worker.kill_worker()
         # Must be done after worker task queues are all closed
@@ -221,6 +222,8 @@ def _run_worker_process(
         try:
             executor = getattr(worker, method)
             output = executor(*args, **kwargs)
+        except KeyboardInterrupt:
+            break
         except BaseException as e:
             tb = traceback.format_exc()
             logger.error(
@@ -26,6 +26,8 @@ logger = init_logger(__name__)

 class RayTPUExecutor(TPUExecutor):

+    uses_ray: bool = True
+
     def __init__(self, *args, **kwargs):
         # This is non-None when the execute model loop is running
         # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
@@ -1,11 +1,11 @@
 # The CLI entrypoint to vLLM.
 import argparse
-import asyncio
 import os
 import signal
 import sys
 from typing import List, Optional

+import uvloop
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam

@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
     # EngineArgs expects the model name to be passed as --model.
     args.model = args.model_tag

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))


 def interactive_cli(args: argparse.Namespace) -> None:
@@ -12,6 +12,7 @@ import tempfile
 import threading
 import uuid
 import warnings
+import weakref
 from asyncio import FIRST_COMPLETED, ensure_future
 from functools import lru_cache, partial, wraps
 from platform import uname
@@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


+def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
+    """Make an instance method that weakly references
+    its associated instance and no-ops once that
+    instance is collected."""
+    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
+    unbound = bound_method.__func__  # type: ignore[attr-defined]
+
+    def weak_bound(*args, **kwargs) -> None:
+        if inst := ref():
+            unbound(inst, *args, **kwargs)
+
+    return weak_bound
+
+
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f: Callable[P, None]) -> Callable[P, None]:
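
The new weak_bind helper is what the engine callbacks above rely on: it wraps a bound method so the resulting callable holds only a weak reference to the instance and silently becomes a no-op once the instance has been garbage collected, rather than keeping the whole engine alive. A small usage sketch (the Counter class is hypothetical; weak_bind is the function added above):

import gc

from vllm.utils import weak_bind


class Counter:

    def __init__(self) -> None:
        self.value = 0

    def bump(self, amount: int = 1) -> None:
        self.value += amount


counter = Counter()
callback = weak_bind(counter.bump)   # keeps only a weak reference to `counter`

callback(5)
print(counter.value)                 # -> 5

del counter
gc.collect()
callback(5)                          # no-op: the instance has been collected
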