diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index fa9b7e4f14c02..cf58e7914972c 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, - has_kv_transfer_group, is_v1_kv_transfer_group) + KVConnectorBaseType, ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group, + is_v1_kv_transfer_group) __all__ = [ "get_kv_transfer_group", "has_kv_transfer_group", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "KVConnectorBaseType" + "ensure_kv_transfer_shutdown", "KVConnectorBaseType" ] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2804003f5a708..f3f493144d283 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC): """ return None, None + def shutdown(self): + """ + Shutdown the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + return None + # ============================== # Scheduler-side methods # ============================== diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 5e0f64fca220c..d5747bed92771 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: config=vllm_config, role=KVConnectorRole.WORKER) else: raise ValueError("V0 is no longer supported") + + +def ensure_kv_transfer_shutdown() -> None: + global _KV_CONNECTOR_AGENT + if _KV_CONNECTOR_AGENT is not None: + _KV_CONNECTOR_AGENT.shutdown() + _KV_CONNECTOR_AGENT = None diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 813232cd19281..a3c1d79a58b26 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -231,7 +231,7 @@ class ExecutorBase(ABC): def shutdown(self) -> None: """Shutdown the executor.""" - return + self.collective_rpc("shutdown") def __del__(self): self.shutdown() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 31f7e9c70f8b3..2d40e96632c95 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1188,6 +1188,8 @@ class Scheduler(SchedulerInterface): def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() ######################################################################## # KV Connector Related Methods diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ef6303495c245..c3d6c20e22e2a 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing -import os import pickle import queue import signal @@ -507,6 +506,7 @@ class WorkerProc: return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() @@ -536,7 +536,7 @@ class WorkerProc: # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - + shutdown_event = threading.Event() # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -548,7 +548,7 @@ class WorkerProc: # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown - os.kill(os.getpid(), signal.SIGTERM) + shutdown_event.set() except Exception as e: logger.warning("Death monitoring error: %s", e) @@ -576,7 +576,7 @@ class WorkerProc: ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -586,6 +586,8 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") + elif shutdown_event.is_set(): + logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -637,11 +639,11 @@ class WorkerProc: output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self): + def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() - + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( + cancel=cancel) try: if isinstance(method, str): func = getattr(self.worker, method) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6a3bc5d46df27..726f59603437f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -601,6 +601,9 @@ class Worker(WorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig, diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index e2ffa2f12fda5..67bb967d2edfa 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035 from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, +from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, + get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context @@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin: # Do this here to save a collective_rpc. kv_connector.start_load_kv(get_forward_context()) + @staticmethod + def ensure_kv_transfer_shutdown() -> None: + if has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 3f4e3ecbd4e26..fc72b954df9cf 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -330,6 +330,9 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a1fa7f2cf7a2e..aa76d21f0fcaa 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -129,6 +129,10 @@ class WorkerBase: """Get vocabulary size from model configuration.""" return self.model_config.get_vocab_size() + def shutdown(self) -> None: + """Clean up resources held by the worker.""" + return + class DelegateWorkerBase(WorkerBase): """ @@ -519,6 +523,10 @@ class WorkerWrapperBase: from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + def shutdown(self) -> None: + if self.worker is not None: + self.worker.shutdown() + def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: """ Adjust the rpc_rank based on the given mapping.