[P/D] Add a shutdown method to the Connector API (#22699)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2025-09-08 14:07:00 +08:00 committed by GitHub
parent 8c892b1831
commit 61aa4b2901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 52 additions and 12 deletions

View File

@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, KVConnectorBaseType, ensure_kv_transfer_initialized,
has_kv_transfer_group, is_v1_kv_transfer_group) ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group)
__all__ = [ __all__ = [
"get_kv_transfer_group", "has_kv_transfer_group", "get_kv_transfer_group", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"KVConnectorBaseType" "ensure_kv_transfer_shutdown", "KVConnectorBaseType"
] ]

View File

@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
""" """
return None, None return None, None
def shutdown(self):
    """Shut down the connector.

    Called while the worker process is shutting down, so that all
    asynchronous operations are completed and the connector is cleaned
    up properly. This base implementation performs no work.

    Returns:
        None.
    """
    return
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================

View File

@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
config=vllm_config, role=KVConnectorRole.WORKER) config=vllm_config, role=KVConnectorRole.WORKER)
else: else:
raise ValueError("V0 is no longer supported") raise ValueError("V0 is no longer supported")
def ensure_kv_transfer_shutdown() -> None:
    """Shut down the module-level KV connector agent, if one is set.

    Calls ``shutdown()`` on ``_KV_CONNECTOR_AGENT`` and then resets the
    global to ``None``. When no agent is set the function does nothing,
    so it can be called repeatedly.

    Raises:
        Whatever the agent's ``shutdown()`` raises. The global reference
        is cleared in that case too, so a later call does not invoke
        ``shutdown()`` a second time on the same agent.
    """
    global _KV_CONNECTOR_AGENT
    if _KV_CONNECTOR_AGENT is None:
        return
    try:
        _KV_CONNECTOR_AGENT.shutdown()
    finally:
        # Release the agent even if its shutdown failed; keeping it would
        # make the next call retry shutdown on a partly torn-down object.
        _KV_CONNECTOR_AGENT = None

View File

@ -231,7 +231,7 @@ class ExecutorBase(ABC):
def shutdown(self) -> None:
    """Shut down the executor.

    Invokes ``collective_rpc("shutdown")`` and discards whatever that
    call returns.

    Returns:
        None.
    """
    # Plain statement rather than ``return ...``: the method is declared
    # ``-> None`` and must not leak the RPC result to callers.
    self.collective_rpc("shutdown")
def __del__(self): def __del__(self):
self.shutdown() self.shutdown()

View File

@ -1188,6 +1188,8 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None:
    """Shut down the KV event publisher and the KV connector.

    The publisher is shut down first when ``self.kv_event_publisher`` is
    truthy; the connector is shut down when ``self.connector`` is not
    ``None``.

    Raises:
        Whatever either ``shutdown()`` call raises. The connector is
        still shut down when the publisher's shutdown fails.
    """
    try:
        if self.kv_event_publisher:
            self.kv_event_publisher.shutdown()
    finally:
        # Run connector cleanup regardless of the publisher's outcome so
        # one failing component does not leave the other one open.
        if self.connector is not None:
            self.connector.shutdown()
######################################################################## ########################################################################
# KV Connector Related Methods # KV Connector Related Methods

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing import multiprocessing
import os
import pickle import pickle
import queue import queue
import signal import signal
@ -507,6 +506,7 @@ class WorkerProc:
return cast(list[WorkerProcHandle], ready_proc_handles) return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self): def shutdown(self):
self.worker.shutdown()
self.rpc_broadcast_mq = None self.rpc_broadcast_mq = None
self.worker_response_mq = None self.worker_response_mq = None
destroy_model_parallel() destroy_model_parallel()
@ -536,7 +536,7 @@ class WorkerProc:
# tuple[Connection, Connection] # tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe") reader, ready_writer = kwargs.pop("ready_pipe")
death_pipe = kwargs.pop("death_pipe", None) death_pipe = kwargs.pop("death_pipe", None)
shutdown_event = threading.Event()
# Start death monitoring thread if death_pipe is provided # Start death monitoring thread if death_pipe is provided
if death_pipe is not None: if death_pipe is not None:
@ -548,7 +548,7 @@ class WorkerProc:
# Parent process has exited, terminate this worker # Parent process has exited, terminate this worker
logger.info("Parent process exited, terminating worker") logger.info("Parent process exited, terminating worker")
# Send signal to self to trigger clean shutdown # Send signal to self to trigger clean shutdown
os.kill(os.getpid(), signal.SIGTERM) shutdown_event.set()
except Exception as e: except Exception as e:
logger.warning("Death monitoring error: %s", e) logger.warning("Death monitoring error: %s", e)
@ -576,7 +576,7 @@ class WorkerProc:
ready_writer.close() ready_writer.close()
ready_writer = None ready_writer = None
worker.worker_busy_loop() worker.worker_busy_loop(cancel=shutdown_event)
except Exception: except Exception:
# NOTE: if an Exception arises in busy_loop, we send # NOTE: if an Exception arises in busy_loop, we send
@ -586,6 +586,8 @@ class WorkerProc:
if ready_writer is not None: if ready_writer is not None:
logger.exception("WorkerProc failed to start.") logger.exception("WorkerProc failed to start.")
elif shutdown_event.is_set():
logger.info("WorkerProc shutting down.")
else: else:
logger.exception("WorkerProc failed.") logger.exception("WorkerProc failed.")
@ -637,11 +639,11 @@ class WorkerProc:
output = self.async_output_queue.get() output = self.async_output_queue.get()
self.enqueue_output(output) self.enqueue_output(output)
def worker_busy_loop(self): def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
"""Main busy loop for Multiprocessing Workers""" """Main busy loop for Multiprocessing Workers"""
while True: while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel)
try: try:
if isinstance(method, str): if isinstance(method, str):
func = getattr(self.worker, method) func = getattr(self.worker, method)

View File

@ -601,6 +601,9 @@ class Worker(WorkerBase):
self.model_runner.save_tensorized_model( self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, ) tensorizer_config=tensorizer_config, )
def shutdown(self) -> None:
    """Shut down KV transfer through the model runner, if there is one.

    Calls ``ensure_kv_transfer_shutdown()`` on ``self.model_runner``.
    When the attribute is missing or ``None`` the call is skipped.

    Returns:
        None.
    """
    # Guarded lookup: shutdown may be requested before a model runner has
    # been attached, and cleanup should not fail with AttributeError then.
    runner = getattr(self, "model_runner", None)
    if runner is not None:
        runner.ensure_kv_transfer_shutdown()
def init_worker_distributed_environment( def init_worker_distributed_environment(
vllm_config: VllmConfig, vllm_config: VllmConfig,

View File

@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin:
# Do this here to save a collective_rpc. # Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context()) kv_connector.start_load_kv(get_forward_context())
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
    """Shut down KV transfer when a KV transfer group exists.

    Does nothing when ``has_kv_transfer_group()`` is false.
    """
    if not has_kv_transfer_group():
        return
    # Inside the class body's method, this bare name resolves to the
    # function imported from vllm.distributed.kv_transfer, not to this
    # method of the same name (class scope is not searched from here).
    ensure_kv_transfer_shutdown()
@staticmethod @staticmethod
def maybe_wait_for_kv_save() -> None: def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group(): if has_kv_transfer_group():

View File

@ -330,6 +330,9 @@ class TPUWorker:
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
def shutdown(self) -> None:
    """Shut down KV transfer through the model runner, if there is one.

    Calls ``ensure_kv_transfer_shutdown()`` on ``self.model_runner``.
    When the attribute is missing or ``None`` the call is skipped.

    Returns:
        None.
    """
    # Guarded lookup so that a shutdown requested before the model runner
    # exists is a no-op instead of an AttributeError.
    runner = getattr(self, "model_runner", None)
    if runner is not None:
        runner.ensure_kv_transfer_shutdown()
if USE_TPU_COMMONS: if USE_TPU_COMMONS:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker from tpu_commons.worker import TPUWorker as TPUCommonsWorker

View File

@ -129,6 +129,10 @@ class WorkerBase:
"""Get vocabulary size from model configuration.""" """Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def shutdown(self) -> None:
    """Clean up resources held by the worker.

    This default implementation does nothing.
    """
    return None
class DelegateWorkerBase(WorkerBase): class DelegateWorkerBase(WorkerBase):
""" """
@ -519,6 +523,10 @@ class WorkerWrapperBase:
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
def shutdown(self) -> None:
    """Forward the shutdown request to the wrapped worker.

    Does nothing when ``self.worker`` is ``None``.
    """
    if self.worker is None:
        return
    self.worker.shutdown()
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
""" """
Adjust the rpc_rank based on the given mapping. Adjust the rpc_rank based on the given mapping.