[P/D] Add a shutdown method to the Connector API (#22699)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Chauncey 2025-09-08 14:07:00 +08:00 committed by GitHub
parent 8c892b1831
commit 61aa4b2901
10 changed files with 52 additions and 12 deletions
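Taken together, the diffs below wire up a single teardown path: ExecutorBase.shutdown() broadcasts a "shutdown" RPC to every worker, the worker wrapper forwards it to the worker, the worker asks its model runner to tear down the KV transfer state, and that in turn invokes the new Connector shutdown() hook. A minimal sketch of the chain, using simplified stand-in classes rather than the real vLLM types:

class Connector:
    def shutdown(self):
        ...  # drain in-flight KV transfers, release resources

class Worker:
    def __init__(self, connector):
        self.connector = connector

    def shutdown(self):
        self.connector.shutdown()

class Executor:
    def __init__(self, workers):
        self.workers = workers

    def collective_rpc(self, method):
        # Fan the method name out to every worker, as the real executor does.
        for worker in self.workers:
            getattr(worker, method)()

    def shutdown(self):
        self.collective_rpc("shutdown")

Executor([Worker(Connector())]).shutdown()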

View File

@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm.distributed.kv_transfer.kv_transfer_state import (
-    KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
-    has_kv_transfer_group, is_v1_kv_transfer_group)
+    KVConnectorBaseType, ensure_kv_transfer_initialized,
+    ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
+    is_v1_kv_transfer_group)
 
 __all__ = [
     "get_kv_transfer_group", "has_kv_transfer_group",
     "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
-    "KVConnectorBaseType"
+    "ensure_kv_transfer_shutdown", "KVConnectorBaseType"
 ]

View File

@@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
         """
         return None, None
 
+    def shutdown(self):
+        """
+        Shutdown the connector. This is called when the worker process
+        is shutting down to ensure that all the async operations are
+        completed and the connector is cleaned up properly.
+        """
+        return None
+
     # ==============================
     # Scheduler-side methods
     # ==============================
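A connector implementation overrides the new hook to flush whatever async machinery it owns. A minimal sketch, assuming a hypothetical connector that tracks transfers on a thread pool (the pool and the class are illustrative, not part of this PR):

from concurrent.futures import ThreadPoolExecutor

class MyConnector:  # would subclass KVConnectorBase_V1 in real code
    def __init__(self):
        # Hypothetical async machinery owned by the connector.
        self._pool = ThreadPoolExecutor(max_workers=4)

    def shutdown(self):
        # Block until queued KV transfers finish, then release the pool.
        self._pool.shutdown(wait=True)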

View File

@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
             config=vllm_config, role=KVConnectorRole.WORKER)
     else:
         raise ValueError("V0 is no longer supported")
+
+
+def ensure_kv_transfer_shutdown() -> None:
+    global _KV_CONNECTOR_AGENT
+    if _KV_CONNECTOR_AGENT is not None:
+        _KV_CONNECTOR_AGENT.shutdown()
+        _KV_CONNECTOR_AGENT = None
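Clearing the module-level agent after the first call makes the helper idempotent, so repeated invocations are harmless. A self-contained sketch of that pattern with a stub agent:

class _StubAgent:
    def shutdown(self):
        print("agent shutdown")

_KV_CONNECTOR_AGENT = _StubAgent()

def ensure_kv_transfer_shutdown() -> None:
    global _KV_CONNECTOR_AGENT
    if _KV_CONNECTOR_AGENT is not None:
        _KV_CONNECTOR_AGENT.shutdown()
        _KV_CONNECTOR_AGENT = None

ensure_kv_transfer_shutdown()  # prints "agent shutdown"
ensure_kv_transfer_shutdown()  # no-op: the agent was already cleared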

View File

@@ -231,7 +231,7 @@ class ExecutorBase(ABC):
 
     def shutdown(self) -> None:
         """Shutdown the executor."""
-        return
+        self.collective_rpc("shutdown")
 
     def __del__(self):
         self.shutdown()

View File

@@ -1188,6 +1188,8 @@ class Scheduler(SchedulerInterface):
     def shutdown(self) -> None:
         if self.kv_event_publisher:
             self.kv_event_publisher.shutdown()
+        if self.connector is not None:
+            self.connector.shutdown()
 
     ########################################################################
     # KV Connector Related Methods

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import multiprocessing
-import os
 import pickle
 import queue
 import signal
@@ -507,6 +506,7 @@ class WorkerProc:
         return cast(list[WorkerProcHandle], ready_proc_handles)
 
     def shutdown(self):
+        self.worker.shutdown()
         self.rpc_broadcast_mq = None
         self.worker_response_mq = None
         destroy_model_parallel()
@@ -536,7 +536,7 @@ class WorkerProc:
         # tuple[Connection, Connection]
         reader, ready_writer = kwargs.pop("ready_pipe")
         death_pipe = kwargs.pop("death_pipe", None)
-
+        shutdown_event = threading.Event()
         # Start death monitoring thread if death_pipe is provided
         if death_pipe is not None:
@@ -548,7 +548,7 @@ class WorkerProc:
                     # Parent process has exited, terminate this worker
                     logger.info("Parent process exited, terminating worker")
                     # Send signal to self to trigger clean shutdown
-                    os.kill(os.getpid(), signal.SIGTERM)
+                    shutdown_event.set()
                 except Exception as e:
                     logger.warning("Death monitoring error: %s", e)
@@ -576,7 +576,7 @@ class WorkerProc:
                 ready_writer.close()
                 ready_writer = None
 
-                worker.worker_busy_loop()
+                worker.worker_busy_loop(cancel=shutdown_event)
 
         except Exception:
             # NOTE: if an Exception arises in busy_loop, we send
@@ -586,6 +586,8 @@ class WorkerProc:
             if ready_writer is not None:
                 logger.exception("WorkerProc failed to start.")
+            elif shutdown_event.is_set():
+                logger.info("WorkerProc shutting down.")
             else:
                 logger.exception("WorkerProc failed.")
@@ -637,11 +639,11 @@ class WorkerProc:
             output = self.async_output_queue.get()
             self.enqueue_output(output)
 
-    def worker_busy_loop(self):
+    def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
+                cancel=cancel)
 
             try:
                 if isinstance(method, str):
                     func = getattr(self.worker, method)
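The death-monitor thread no longer SIGTERMs its own process; it sets a threading.Event, and the blocking dequeue in the busy loop observes that event (the diff shows rpc_broadcast_mq.dequeue accepting a cancel argument). A minimal sketch of the cooperative-cancel pattern, with a plain queue.Queue standing in for the shared-memory message queue:

import queue
import threading
import time

def busy_loop(q: queue.Queue, cancel: threading.Event):
    # Poll with a timeout so the cancel event is noticed even when idle.
    while not cancel.is_set():
        try:
            method = q.get(timeout=0.1)
        except queue.Empty:
            continue
        print(f"dispatching {method}")

q: queue.Queue = queue.Queue()
cancel = threading.Event()
t = threading.Thread(target=busy_loop, args=(q, cancel))
t.start()
q.put("shutdown")
time.sleep(0.2)  # let the loop dispatch the queued item
cancel.set()     # cooperative shutdown: the loop exits cleanly
t.join()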

View File

@@ -601,6 +601,9 @@ class Worker(WorkerBase):
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def shutdown(self) -> None:
+        self.model_runner.ensure_kv_transfer_shutdown()
+
 
 def init_worker_distributed_environment(
     vllm_config: VllmConfig,

View File

@@ -9,7 +9,8 @@ from typing import Generator  # noqa: UP035
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
+                                          get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
 from vllm.forward_context import get_forward_context, set_forward_context
@@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin:
             # Do this here to save a collective_rpc.
             kv_connector.start_load_kv(get_forward_context())
 
+    @staticmethod
+    def ensure_kv_transfer_shutdown() -> None:
+        if has_kv_transfer_group():
+            ensure_kv_transfer_shutdown()
+
     @staticmethod
     def maybe_wait_for_kv_save() -> None:
         if has_kv_transfer_group():

View File

@@ -330,6 +330,9 @@ class TPUWorker:
         ensure_kv_transfer_initialized(vllm_config)
 
+    def shutdown(self) -> None:
+        self.model_runner.ensure_kv_transfer_shutdown()
+
 
 if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker

View File

@@ -129,6 +129,10 @@ class WorkerBase:
         """Get vocabulary size from model configuration."""
         return self.model_config.get_vocab_size()
 
+    def shutdown(self) -> None:
+        """Clean up resources held by the worker."""
+        return
+
 
 class DelegateWorkerBase(WorkerBase):
     """
@@ -519,6 +523,10 @@ class WorkerWrapperBase:
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
+    def shutdown(self) -> None:
+        if self.worker is not None:
+            self.worker.shutdown()
+
     def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
         """
         Adjust the rpc_rank based on the given mapping.
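The None check in the wrapper matters because, presumably, the shutdown RPC can arrive before init_worker has populated self.worker. A tiny sketch of why the guard keeps that path safe:

class WorkerWrapper:
    def __init__(self):
        self.worker = None  # init_worker() assigns the real worker later

    def shutdown(self):
        if self.worker is not None:
            self.worker.shutdown()

WorkerWrapper().shutdown()  # safe no-op before the worker exists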