mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 05:03:04 +08:00
[P/D] Add a shutdown method to the Connector API (#22699)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
8c892b1831
commit
61aa4b2901
@ -2,11 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
|
||||
has_kv_transfer_group, is_v1_kv_transfer_group)
|
||||
KVConnectorBaseType, ensure_kv_transfer_initialized,
|
||||
ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
|
||||
__all__ = [
|
||||
"get_kv_transfer_group", "has_kv_transfer_group",
|
||||
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
|
||||
"KVConnectorBaseType"
|
||||
"ensure_kv_transfer_shutdown", "KVConnectorBaseType"
|
||||
]
|
||||
|
||||
@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the connector. This is called when the worker process
|
||||
is shutting down to ensure that all the async operations are
|
||||
completed and the connector is cleaned up properly.
|
||||
"""
|
||||
return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
||||
config=vllm_config, role=KVConnectorRole.WORKER)
|
||||
else:
|
||||
raise ValueError("V0 is no longer supported")
|
||||
|
||||
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
global _KV_CONNECTOR_AGENT
|
||||
if _KV_CONNECTOR_AGENT is not None:
|
||||
_KV_CONNECTOR_AGENT.shutdown()
|
||||
_KV_CONNECTOR_AGENT = None
|
||||
|
||||
@ -231,7 +231,7 @@ class ExecutorBase(ABC):
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the executor."""
|
||||
return
|
||||
self.collective_rpc("shutdown")
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
@ -1188,6 +1188,8 @@ class Scheduler(SchedulerInterface):
|
||||
def shutdown(self) -> None:
|
||||
if self.kv_event_publisher:
|
||||
self.kv_event_publisher.shutdown()
|
||||
if self.connector is not None:
|
||||
self.connector.shutdown()
|
||||
|
||||
########################################################################
|
||||
# KV Connector Related Methods
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
@ -507,6 +506,7 @@ class WorkerProc:
|
||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||
|
||||
def shutdown(self):
|
||||
self.worker.shutdown()
|
||||
self.rpc_broadcast_mq = None
|
||||
self.worker_response_mq = None
|
||||
destroy_model_parallel()
|
||||
@ -536,7 +536,7 @@ class WorkerProc:
|
||||
# tuple[Connection, Connection]
|
||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||
death_pipe = kwargs.pop("death_pipe", None)
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
# Start death monitoring thread if death_pipe is provided
|
||||
if death_pipe is not None:
|
||||
|
||||
@ -548,7 +548,7 @@ class WorkerProc:
|
||||
# Parent process has exited, terminate this worker
|
||||
logger.info("Parent process exited, terminating worker")
|
||||
# Send signal to self to trigger clean shutdown
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
shutdown_event.set()
|
||||
except Exception as e:
|
||||
logger.warning("Death monitoring error: %s", e)
|
||||
|
||||
@ -576,7 +576,7 @@ class WorkerProc:
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
worker.worker_busy_loop()
|
||||
worker.worker_busy_loop(cancel=shutdown_event)
|
||||
|
||||
except Exception:
|
||||
# NOTE: if an Exception arises in busy_loop, we send
|
||||
@ -586,6 +586,8 @@ class WorkerProc:
|
||||
|
||||
if ready_writer is not None:
|
||||
logger.exception("WorkerProc failed to start.")
|
||||
elif shutdown_event.is_set():
|
||||
logger.info("WorkerProc shutting down.")
|
||||
else:
|
||||
logger.exception("WorkerProc failed.")
|
||||
|
||||
@ -637,11 +639,11 @@ class WorkerProc:
|
||||
output = self.async_output_queue.get()
|
||||
self.enqueue_output(output)
|
||||
|
||||
def worker_busy_loop(self):
|
||||
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||
cancel=cancel)
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
func = getattr(self.worker, method)
|
||||
|
||||
@ -601,6 +601,9 @@ class Worker(WorkerBase):
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin:
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
@staticmethod
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
if has_kv_transfer_group():
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
@staticmethod
|
||||
def maybe_wait_for_kv_save() -> None:
|
||||
if has_kv_transfer_group():
|
||||
|
||||
@ -330,6 +330,9 @@ class TPUWorker:
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
if USE_TPU_COMMONS:
|
||||
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
||||
|
||||
@ -129,6 +129,10 @@ class WorkerBase:
|
||||
"""Get vocabulary size from model configuration."""
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Clean up resources held by the worker."""
|
||||
return
|
||||
|
||||
|
||||
class DelegateWorkerBase(WorkerBase):
|
||||
"""
|
||||
@ -519,6 +523,10 @@ class WorkerWrapperBase:
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.worker is not None:
|
||||
self.worker.shutdown()
|
||||
|
||||
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rpc_rank based on the given mapping.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user