mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 10:07:01 +08:00
[P/D] Add a shutdown method to the Connector API (#22699)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
8c892b1831
commit
61aa4b2901
@ -2,11 +2,12 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||||
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
|
KVConnectorBaseType, ensure_kv_transfer_initialized,
|
||||||
has_kv_transfer_group, is_v1_kv_transfer_group)
|
ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
|
||||||
|
is_v1_kv_transfer_group)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_kv_transfer_group", "has_kv_transfer_group",
|
"get_kv_transfer_group", "has_kv_transfer_group",
|
||||||
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
|
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
|
||||||
"KVConnectorBaseType"
|
"ensure_kv_transfer_shutdown", "KVConnectorBaseType"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""
|
||||||
|
Shut down the connector. This is called when the worker process
|
||||||
|
is shutting down to ensure that all the async operations are
|
||||||
|
completed and the connector is cleaned up properly.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|||||||
@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
|||||||
config=vllm_config, role=KVConnectorRole.WORKER)
|
config=vllm_config, role=KVConnectorRole.WORKER)
|
||||||
else:
|
else:
|
||||||
raise ValueError("V0 is no longer supported")
|
raise ValueError("V0 is no longer supported")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_kv_transfer_shutdown() -> None:
|
||||||
|
global _KV_CONNECTOR_AGENT
|
||||||
|
if _KV_CONNECTOR_AGENT is not None:
|
||||||
|
_KV_CONNECTOR_AGENT.shutdown()
|
||||||
|
_KV_CONNECTOR_AGENT = None
|
||||||
|
|||||||
@ -231,7 +231,7 @@ class ExecutorBase(ABC):
|
|||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shutdown the executor."""
|
"""Shutdown the executor."""
|
||||||
return
|
self.collective_rpc("shutdown")
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|||||||
@ -1188,6 +1188,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
if self.kv_event_publisher:
|
if self.kv_event_publisher:
|
||||||
self.kv_event_publisher.shutdown()
|
self.kv_event_publisher.shutdown()
|
||||||
|
if self.connector is not None:
|
||||||
|
self.connector.shutdown()
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# KV Connector Related Methods
|
# KV Connector Related Methods
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import queue
|
import queue
|
||||||
import signal
|
import signal
|
||||||
@ -507,6 +506,7 @@ class WorkerProc:
|
|||||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
self.worker.shutdown()
|
||||||
self.rpc_broadcast_mq = None
|
self.rpc_broadcast_mq = None
|
||||||
self.worker_response_mq = None
|
self.worker_response_mq = None
|
||||||
destroy_model_parallel()
|
destroy_model_parallel()
|
||||||
@ -536,7 +536,7 @@ class WorkerProc:
|
|||||||
# tuple[Connection, Connection]
|
# tuple[Connection, Connection]
|
||||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||||
death_pipe = kwargs.pop("death_pipe", None)
|
death_pipe = kwargs.pop("death_pipe", None)
|
||||||
|
shutdown_event = threading.Event()
|
||||||
# Start death monitoring thread if death_pipe is provided
|
# Start death monitoring thread if death_pipe is provided
|
||||||
if death_pipe is not None:
|
if death_pipe is not None:
|
||||||
|
|
||||||
@ -548,7 +548,7 @@ class WorkerProc:
|
|||||||
# Parent process has exited, terminate this worker
|
# Parent process has exited, terminate this worker
|
||||||
logger.info("Parent process exited, terminating worker")
|
logger.info("Parent process exited, terminating worker")
|
||||||
# Send signal to self to trigger clean shutdown
|
# Set the shutdown event, which is passed to the busy loop as its cancel signal
|
||||||
os.kill(os.getpid(), signal.SIGTERM)
|
shutdown_event.set()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Death monitoring error: %s", e)
|
logger.warning("Death monitoring error: %s", e)
|
||||||
|
|
||||||
@ -576,7 +576,7 @@ class WorkerProc:
|
|||||||
ready_writer.close()
|
ready_writer.close()
|
||||||
ready_writer = None
|
ready_writer = None
|
||||||
|
|
||||||
worker.worker_busy_loop()
|
worker.worker_busy_loop(cancel=shutdown_event)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# NOTE: if an Exception arises in busy_loop, we send
|
# NOTE: if an Exception arises in busy_loop, we send
|
||||||
@ -586,6 +586,8 @@ class WorkerProc:
|
|||||||
|
|
||||||
if ready_writer is not None:
|
if ready_writer is not None:
|
||||||
logger.exception("WorkerProc failed to start.")
|
logger.exception("WorkerProc failed to start.")
|
||||||
|
elif shutdown_event.is_set():
|
||||||
|
logger.info("WorkerProc shutting down.")
|
||||||
else:
|
else:
|
||||||
logger.exception("WorkerProc failed.")
|
logger.exception("WorkerProc failed.")
|
||||||
|
|
||||||
@ -637,11 +639,11 @@ class WorkerProc:
|
|||||||
output = self.async_output_queue.get()
|
output = self.async_output_queue.get()
|
||||||
self.enqueue_output(output)
|
self.enqueue_output(output)
|
||||||
|
|
||||||
def worker_busy_loop(self):
|
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
||||||
"""Main busy loop for Multiprocessing Workers"""
|
"""Main busy loop for Multiprocessing Workers"""
|
||||||
while True:
|
while True:
|
||||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
|
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||||
|
cancel=cancel)
|
||||||
try:
|
try:
|
||||||
if isinstance(method, str):
|
if isinstance(method, str):
|
||||||
func = getattr(self.worker, method)
|
func = getattr(self.worker, method)
|
||||||
|
|||||||
@ -601,6 +601,9 @@ class Worker(WorkerBase):
|
|||||||
self.model_runner.save_tensorized_model(
|
self.model_runner.save_tensorized_model(
|
||||||
tensorizer_config=tensorizer_config, )
|
tensorizer_config=tensorizer_config, )
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
self.model_runner.ensure_kv_transfer_shutdown()
|
||||||
|
|
||||||
|
|
||||||
def init_worker_distributed_environment(
|
def init_worker_distributed_environment(
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
|
||||||
|
get_kv_transfer_group,
|
||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin:
|
|||||||
# Do this here to save a collective_rpc.
|
# Do this here to save a collective_rpc.
|
||||||
kv_connector.start_load_kv(get_forward_context())
|
kv_connector.start_load_kv(get_forward_context())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ensure_kv_transfer_shutdown() -> None:
|
||||||
|
if has_kv_transfer_group():
|
||||||
|
ensure_kv_transfer_shutdown()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maybe_wait_for_kv_save() -> None:
|
def maybe_wait_for_kv_save() -> None:
|
||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
|
|||||||
@ -330,6 +330,9 @@ class TPUWorker:
|
|||||||
|
|
||||||
ensure_kv_transfer_initialized(vllm_config)
|
ensure_kv_transfer_initialized(vllm_config)
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
self.model_runner.ensure_kv_transfer_shutdown()
|
||||||
|
|
||||||
|
|
||||||
if USE_TPU_COMMONS:
|
if USE_TPU_COMMONS:
|
||||||
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
||||||
|
|||||||
@ -129,6 +129,10 @@ class WorkerBase:
|
|||||||
"""Get vocabulary size from model configuration."""
|
"""Get vocabulary size from model configuration."""
|
||||||
return self.model_config.get_vocab_size()
|
return self.model_config.get_vocab_size()
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""Clean up resources held by the worker."""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class DelegateWorkerBase(WorkerBase):
|
class DelegateWorkerBase(WorkerBase):
|
||||||
"""
|
"""
|
||||||
@ -519,6 +523,10 @@ class WorkerWrapperBase:
|
|||||||
from vllm.utils import init_cached_hf_modules
|
from vllm.utils import init_cached_hf_modules
|
||||||
init_cached_hf_modules()
|
init_cached_hf_modules()
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
if self.worker is not None:
|
||||||
|
self.worker.shutdown()
|
||||||
|
|
||||||
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
||||||
"""
|
"""
|
||||||
Adjust the rpc_rank based on the given mapping.
|
Adjust the rpc_rank based on the given mapping.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user