From 66e233268229afee03fe9dcfa7a14a2a925238ba Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 22 Dec 2025 04:55:29 +0000 Subject: [PATCH] split large file Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/factory.py | 2 +- .../kv_connector/v1/moriio/__init__.py | 0 .../kv_connector/v1/moriio/moriio_common.py | 321 +++++++ .../v1/{ => moriio}/moriio_connector.py | 878 +----------------- .../kv_connector/v1/moriio/moriio_engine.py | 607 ++++++++++++ 5 files changed, 954 insertions(+), 854 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py rename vllm/distributed/kv_transfer/kv_connector/v1/{ => moriio}/moriio_connector.py (67%) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 954a5153ff67d..fd3d1e76d2450 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -181,7 +181,7 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MoRIIOConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector", + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", "MoRIIOConnector", ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py new file mode 100644 index 0000000000000..026e7faf57616 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + get_ip, + get_open_port, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from dataclasses import field +from enum import Enum + +logger = init_logger(__name__) + + +Transfer = tuple[int, float] +EngineId = str +ReqId = str + + +@dataclass +class WriteTask: + request_id: str + dst_engine_id: str + local_block_ids: list[int] + remote_block_ids_hint: list[int] | None + layer_name: str + event: torch.cuda.Event + remote_notify_port: int + remote_ip: str + enqueue_time: float = field(default_factory=time.perf_counter) + retried: int = 0 + + +@dataclass +class LayerTransferPlan: + """Plan for transferring a single layer.""" + + request_id: str + layer_name: str + sess_idx: int + transfer_local_offsets: list[int] + transfer_remote_offsets: list[int] + transfer_sizes: list[int] + use_batch: bool = True + + +@dataclass +class RemoteAllocInfo: + """Information about remote block allocation.""" + + block_ids: list[int] + 
writes_done: int = 0
+    decode_dp_rank: int = 0
+    transfer_offset: tuple[list[int], list[int], list[int]] | None = None
+
+
+class ROLE(Enum):
+    PRODUCER = "producer"
+    CONSUMER = "consumer"
+    NOTINIT = "notinit"
+
+
+class MoRIIOAgentMetadata(
+    msgspec.Struct,
+    omit_defaults=True,  # type: ignore[call-arg]
+    # required for @cached_property
+    dict=True,
+):
+    engine_id: str
+    agent_metadata: bytes
+    kv_caches_base_addr: list[int]
+    num_blocks: int
+    block_len: int
+    attn_backend_name: str
+
+
+class RoleManager:
+    """Manages role state across the connector."""
+
+    _instance: Optional["RoleManager"] = None
+    _lock = threading.Lock()
+
+    def __init__(self) -> None:
+        self._role: ROLE = ROLE.NOTINIT
+
+    @classmethod
+    def get_instance(cls) -> "RoleManager":
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = cls()
+        return cls._instance
+
+    def set_role(self, role: ROLE) -> None:
+        """Set the current role."""
+        with self._lock:
+            self._role = role
+
+    def get_role(self) -> ROLE:
+        """Get the current role."""
+        return self._role
+
+
+def set_role(role: ROLE):
+    """Set the global role."""
+    RoleManager.get_instance().set_role(role)
+
+
+def get_role() -> ROLE:
+    """Get the global role."""
+    return RoleManager.get_instance().get_role()
+
+
+class MoRIIOMode(Enum):
+    READ = "read"
+    WRITE = "write"
+
+
+class MoRIIOError(Exception):
+    """Base exception for MoRIIO operations."""
+
+    pass
+
+
+class HandshakeError(MoRIIOError):
+    """Exception raised when handshake fails."""
+
+    pass
+
+
+class TransferError(MoRIIOError):
+    """Exception raised when transfer fails."""
+
+    pass
+
+
+def get_moriio_mode() -> MoRIIOMode:
+    read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
+    logger.debug("MoRIIO Connector read_mode: %s", read_mode)
+    if read_mode:
+        return MoRIIOMode.READ
+    else:
+        return MoRIIOMode.WRITE
+
+
+def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
+    return dp_rank * tp_size + tp_rank
+
+
+@dataclass
+class MoRIIOConfig:
+    local_ip: str
+    local_kv_port: int
+    proxy_ip: str
+    local_ping_port: int
+    proxy_ping_port: int
+    http_port: int
+    handshake_port: int
+    notify_port: int
+    tp_rank: int
+    dp_rank: int
+    dp_size: int
+    tp_size: int
+
+    @classmethod
+    def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
+        # Port configuration:
+        #   local_ping_port -> Outgoing heartbeat to the proxy
+        #   proxy_ping_port -> Remote proxy's heartbeat ingress port
+        #   http_port       -> This instance's HTTP service endpoint
+        #   local_kv_port   -> Service port for the mori engine
+        #   notify_port     -> Synchronizes stages between prefill and decode
+        #   handshake_port  -> Initial handshake between mori engines
+
+        # TODO: merge notify_port and handshake_port to simplify port
+        # management and support non-contiguous ports.
+        assert vllm_config.kv_transfer_config is not None, (
+            "kv_transfer_config must be set for MoRIIOConnector"
+        )
+        kv_transfer_config = vllm_config.kv_transfer_config
+        extra_config = kv_transfer_config.kv_connector_extra_config
+        tp_rank = get_tensor_model_parallel_rank()
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        base_notify_port = int(extra_config["notify_port"])
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        tp_size = get_tensor_model_parallel_world_size()
+        port_offset = get_port_offset(dp_rank, tp_rank)
+
+        return cls(
+            local_ip=get_ip(),
+            local_kv_port=get_open_port(),
+            proxy_ip=extra_config["proxy_ip"],
+            local_ping_port=get_open_port(),
+            proxy_ping_port=int(extra_config["proxy_ping_port"]),
http_port=int(extra_config["http_port"]), + handshake_port=int(extra_config["handshake_port"]), + notify_port=base_notify_port + port_offset, + tp_rank=tp_rank, + dp_rank=dp_rank, + dp_size=dp_size, + tp_size=tp_size, + ) + + +class MoRIIOConstants: + """Constants for MoRIIO connector.""" + + # ZMQ message types + GET_META_MSG = b"get_meta_msg" + POP_DONE_RECV = b"pop_done_recv" + OVER = b"OVER" + COMPLETION_PREFIX = "cmpl" + + PING_INTERVAL = 5 + MAX_PING_RETRIES = 100 + DEFAULT_HANDSHAKE_PORT = "6301" + DEFAULT_NOTIFY_PORT = "61005" + + VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 + + +@dataclass +class ReqMeta: + """Metadata for a single request.""" + + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_handshake_port: int + remote_notify_port: int + remote_engine_id: str + tp_size: int + remote_dp_size: int + + +class MoRIIOConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_save: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: dict[ReqId, float] = {} + + def __repr__(self): + return_str = "" + for req_id, req_meta in self.reqs_to_recv.items(): + return_str += ( + f"{req_id = },{req_meta.local_block_ids = }," + f"{req_meta.remote_host = },{req_meta.remote_port = }" + f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" + ) + return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," + + for req_id, expiry in self.reqs_to_send.items(): + return_str += f"{req_id = },{expiry = }" + return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," + return return_str + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + write_mode=False, + ): + _req = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_handshake_port=kv_transfer_params["remote_handshake_port"], + remote_notify_port=kv_transfer_params["remote_notify_port"], + tp_size=kv_transfer_params.get("tp_size", 1), + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), + ) + if write_mode: + self.reqs_to_save[request_id] = _req + else: + self.reqs_to_recv[request_id] = _req + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: zmq.Context | None = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py similarity index 67% rename from vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py rename to vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 1ce1f435cf544..4b6bd906d5d44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -1,17 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
-import contextlib import logging import math import queue import threading import time from collections import defaultdict -from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional -from weakref import ref as weakref_ref import msgpack import msgspec @@ -19,7 +15,6 @@ import numpy as np import torch import zmq -from vllm import envs from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -27,8 +22,29 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + EngineId, + HandshakeError, + MoRIIOAgentMetadata, + MoRIIOConfig, + MoRIIOConnectorMetadata, + MoRIIOConstants, + MoRIIOMode, + ReqId, + ReqMeta, + WriteTask, + get_moriio_mode, + get_port_offset, + get_role, + set_role, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine import ( + MoRIIOWrapper, + MoRIIOWriter, +) from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, get_world_group, @@ -37,7 +53,6 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.utils.network_utils import ( get_ip, - get_open_port, make_zmq_path, make_zmq_socket, ) @@ -50,43 +65,13 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -from dataclasses import field -from enum import Enum -from queue import Empty, Queue - logger = init_logger(__name__) -Transfer = tuple[int, float] -EngineId = str -ReqId = str - - -class MoRIIOConstants: - """Constants for MoRIIO connector.""" - - # ZMQ message types - GET_META_MSG = b"get_meta_msg" - POP_DONE_RECV = b"pop_done_recv" - OVER = b"OVER" - COMPLETION_PREFIX = "cmpl" - - PING_INTERVAL = 5 - MAX_PING_RETRIES = 100 - DEFAULT_HANDSHAKE_PORT = "6301" - DEFAULT_NOTIFY_PORT = "61005" - - VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 - - try: from mori.io import ( BackendType, - EngineDesc, IOEngine, IOEngineConfig, - MemoryDesc, - PollCqMode, - RdmaBackendConfig, ) logger.info("MoRIIO is available") @@ -96,803 +81,8 @@ except ImportError: MoRIIO_enabled = False -@dataclass -class WriteTask: - request_id: str - dst_engine_id: str - local_block_ids: list[int] - remote_block_ids_hint: list[int] | None - layer_name: str - event: torch.cuda.Event - remote_notify_port: int - remote_ip: str - enqueue_time: float = field(default_factory=time.perf_counter) - retried: int = 0 - - -@dataclass -class LayerTransferPlan: - """Plan for transferring a single layer.""" - - request_id: str - layer_name: str - sess_idx: int - transfer_local_offsets: list[int] - transfer_remote_offsets: list[int] - transfer_sizes: list[int] - use_batch: bool = True - - -@dataclass -class RemoteAllocInfo: - """Information about remote block allocation.""" - - block_ids: list[int] - writes_done: int = 0 - decode_dp_rank: int = 0 - transfer_offset: tuple[list[int], list[int], list[int]] | None = None - - -class ROLE(Enum): - PRODUCER = "producer" - CONSUMER = "consumer" - NOTINIT = "notinit" - - -class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True, -): - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: 
list[int] - num_blocks: int - block_len: int - attn_backend_name: str - - -class RoleManager: - """Manages role state across the connector.""" - - _instance: Optional["RoleManager"] = None - _lock = threading.Lock() - - def __init__(self) -> None: - self._role: ROLE = ROLE.NOTINIT - - @classmethod - def get_instance(cls) -> "RoleManager": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def set_role(self, role: ROLE) -> None: - """Set the current role.""" - with self._lock: - self._role = role - - def get_role(self) -> ROLE: - """Get the current role.""" - return self._role - - -def set_role(role: ROLE): - """Set the global role.""" - RoleManager.get_instance().set_role(role) - - -def get_role() -> ROLE: - """Get the global role.""" - return RoleManager.get_instance().get_role() - - -class MoRIIOMode(Enum): - READ = "read" - WRITE = "write" - - -class MoRIIOError(Exception): - """Base exception for MoRIIO operations.""" - - pass - - -class HandshakeError(MoRIIOError): - """Exception raised when handshake fails.""" - - pass - - -class TransferError(MoRIIOError): - """Exception raised when transfer fails.""" - - pass - - -def get_moriio_mode() -> MoRIIOMode: - read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE - logger.debug("MoRIIO Connector read_mode: %s", read_mode) - if read_mode: - return MoRIIOMode.READ - else: - return MoRIIOMode.WRITE - - -def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: - return (dp_rank) * tp_size + tp_rank - - -@dataclass -class MoRIIOConfig: - local_ip: str - local_kv_port: int - proxy_ip: str - local_ping_port: int - proxy_ping_port: int - http_port: int - handshake_port: int - notify_port: int - tp_rank: int - dp_rank: int - dp_size: int - tp_size: int - - @classmethod - def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - # Port Configuration: - # local_ping_port -> Outgoing heartbeat to proxy - # proxy_ping_port -> Remote proxy's heartbeat ingress port - # http_port -> Instance's HTTP service endpoint - # local_kv_port -> service port for mori engine - # notify_port -> For synchronizing stages between prefill and decode - # handshake_port -> For initial handshake between mori engine - - # TODO : merge notify_port and handshake_port to simplify port management - # supports non-contiguous ports - assert vllm_config.kv_transfer_config is not None, ( - "kv_transfer_config must be set for MoRIIOConnector" - ) - kv_transfer_config = vllm_config.kv_transfer_config - extra_config = kv_transfer_config.kv_connector_extra_config - tp_rank = get_tensor_model_parallel_rank() - dp_rank = vllm_config.parallel_config.data_parallel_rank - base_notify_port = int(extra_config["notify_port"]) - dp_size = vllm_config.parallel_config.data_parallel_size - tp_size = get_tensor_model_parallel_world_size() - port_offset = get_port_offset(dp_rank, tp_rank) - - return cls( - local_ip=get_ip(), - local_kv_port=get_open_port(), - proxy_ip=extra_config["proxy_ip"], - local_ping_port=get_open_port(), - proxy_ping_port=int(extra_config["proxy_ping_port"]), - http_port=int(extra_config["http_port"]), - handshake_port=int(extra_config["handshake_port"]), - notify_port=base_notify_port + port_offset, - tp_rank=tp_rank, - dp_rank=dp_rank, - dp_size=dp_size, - tp_size=tp_size, - ) - - -"""Write task execution logic for MoRIIO connector.""" - - -class MoRIIOWriter: - """Handles write operations for KV cache transfers. 
- Implements distributed KV cache transfer using the MoRIIO library - for RDMA-based communication between prefill and decode instances.""" - - def __init__(self, worker: "MoRIIOConnectorWorker"): - """Initialize the writer. - - Args: - worker: Reference to the parent worker - """ - self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) - self._write_task_q: Queue[WriteTask] = Queue() - self._write_worker_started = False - self._write_worker_lock = threading.Lock() - self._deferred_tasks: list[WriteTask] = [] - - @property - def worker(self) -> "MoRIIOConnectorWorker": - """Get the worker instance. - - Returns: - The parent worker instance - - Raises: - RuntimeError: If worker has been garbage collected - """ - worker = self._worker_ref() - if worker is None: - raise RuntimeError("Parent worker has been garbage collected") - return worker - - def ensure_worker_started(self) -> None: - """Ensure the background write worker is running.""" - if self._write_worker_started: - return - self._write_worker_started = True - with self._write_worker_lock: - thread = threading.Thread( - target=self._write_worker_loop, daemon=True, name="moriio-write-worker" - ) - thread.start() - logger.info("Started MoRIIO write worker thread") - - def schedule_write(self, task: WriteTask) -> None: - """Schedule a write task. - - Args: - task: The write task to schedule - """ - self.ensure_worker_started() - self._write_task_q.put(task) - - def _write_worker_loop(self) -> None: - """Main loop for the write worker thread.""" - - while True: - # Process deferred tasks first - self._process_deferred_tasks() - - # Get new task - try: - task = self._write_task_q.get(timeout=0.01) - except Empty: - continue - - # Check if remote blocks are ready - if not self._is_remote_ready(task): - # task.retry_count += 1 - self._deferred_tasks.append(task) - # logger.debug( - # "Deferred task for request %s (retry %d)", - # task.request_id, task.retry_count - # ) - continue - - # Execute the task - - self._execute_write_task(task) - - def _process_deferred_tasks(self) -> None: - """Process tasks that were previously deferred.""" - if not self._deferred_tasks: - return - - still_deferred: list[WriteTask] = [] - for task in self._deferred_tasks: - if self._is_remote_ready(task): - self._execute_write_task(task) - else: - still_deferred.append(task) - - self._deferred_tasks = still_deferred - - def _is_remote_ready(self, task: WriteTask) -> bool: - """Check if remote blocks are allocated for this task. - - Args: - task: The write task - - Returns: - True if remote blocks are ready - """ - return ( - task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict - ) - - def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: - """Get remote allocation info for a request. - - Args: - request_id: The request ID - - Returns: - Remote allocation information - - Raises: - KeyError: If allocation info is missing - """ - try: - return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] - except KeyError as e: - raise KeyError( - f"Remote allocation info missing for request {request_id}" - ) from e - - def _execute_write_task(self, task: WriteTask) -> None: - """Execute a single write task. 
- - Args: - task: The write task to execute - - """ - # Get remote allocation info - request_info = self._get_remote_alloc_info(task.request_id) - - if request_info.block_ids is None: - logger.debug("Request %s remote block IDs not ready", task.request_id) - return - - # Wait for CUDA event - # The attention computation of the current layer cannot - # overlap with the kv transfer task, - # otherwise it will cause precision issues. - # This event is used to synchronize the kv transfer and computation tasks. - task.event.synchronize() - - # Update engine ID with DP rank - task.dst_engine_id = self.worker.get_engine_name_with_dp( - task.dst_engine_id, request_info.decode_dp_rank - ) - - # Get or create sessions - sessions, remote_moriio_meta = self.worker._get_built_session( - task.dst_engine_id - ) - - # Prepare transfer plan - plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) - - # Execute transfer - self._do_layer_write(plan, sessions) - - # Finalize if all layers complete - self._finalize_if_complete(task, request_info) - - def _prepare_transfer_plan( - self, - task: WriteTask, - request_info: RemoteAllocInfo, - remote_moriio_meta: MoRIIOAgentMetadata, - ) -> LayerTransferPlan: - """Prepare the transfer plan for a layer. - - Args: - task: The write task - request_info: Remote allocation information - - Returns: - The transfer plan - """ - # Compute offsets if not cached - if request_info.transfer_offset is None: - offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, - task.local_block_ids, - request_info.block_ids, - remote_moriio_meta, - ) - request_info.transfer_offset = offsets - - # Get session index - layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) - sess_idx = layer_names.index(task.layer_name) - - local_off, remote_off, sizes = request_info.transfer_offset - - return LayerTransferPlan( - request_id=task.request_id, - layer_name=task.layer_name, - sess_idx=sess_idx, - transfer_local_offsets=local_off, - transfer_remote_offsets=remote_off, - transfer_sizes=sizes, - use_batch=True, - ) - - def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: - """Perform the actual layer write. - - Args: - plan: The transfer plan - sessions: List of transfer sessions - """ - if plan.use_batch: - self.worker.moriio_wrapper.write_remote_data( - plan.transfer_sizes, - plan.transfer_local_offsets, - plan.transfer_remote_offsets, - sessions[plan.sess_idx], - ) - else: - for i in range(len(plan.transfer_local_offsets)): - self.worker.moriio_wrapper.write_remote_data_single( - plan.transfer_sizes[i], - plan.transfer_local_offsets[i], - plan.transfer_remote_offsets[i], - plan.sess_idx, - ) - - def _finalize_if_complete( - self, task: WriteTask, request_info: RemoteAllocInfo - ) -> None: - """Finalize transfer if all layers are complete. - - Args: - task: The write task - request_info: Remote allocation information - """ - request_info.writes_done += 1 - - if request_info.writes_done >= self.worker.num_layers: - # Wait for transfer to complete - self.worker.moriio_wrapper.waiting_for_transfer_complete() - - remote_port = task.remote_notify_port + get_port_offset( - request_info.decode_dp_rank, self.worker.tp_rank - ) - # Consider using RDMA immediate data in decode side - # to eliminate the need for this notification. 
- # Consider including the first gen token from prefill in the notification - - # Send completion notification - self.worker.moriio_wrapper.send_notify( - task.request_id, task.remote_ip, remote_port - ) - # mark request as done, then we can free the blocks - with self.worker.moriio_wrapper.lock: - self.worker.moriio_wrapper.done_req_ids.append(task.request_id) - del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ - task.request_id - ] - logger.debug( - "Completed transfer for request %s, notified port %d", - task.request_id, - remote_port, - ) - - -class MoRIIOWrapper: - """Wrapper for MoRIIO engine operations. - - Handles both producer and consumer roles for KV cache transfers. - - Args: - moriio_engine: MoRIIO engine instance - tp_rank: Tensor parallel rank - dp_rank: Data parallel rank - """ - - def __init__( - self, - moriio_engine: Optional["IOEngine"] = None, - tp_rank: int = 0, - dp_rank: int = 0, - ): - self.tp_rank = tp_rank - self.dp_rank = dp_rank - self.moriio_engine = moriio_engine - self.remote_memory_metadata = None - self.local_memory_registered = False - self.local_memory_metadata = None - self.transfer_status: list[Any] = [] - self.remote_engine_ip: str | None = None - self.notify_port: int | None = None - self.lock = threading.Lock() - self.done_req_ids: list[str] = [] - self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} - self.done_write_cache_req_ids: list[str] = [] - self.notify_thread: threading.Thread | None = None - self.sessions: list[IOEngine.Session] = [] - self.paths: dict[str, zmq.Socket] = {} - - def set_moriio_engine(self, moriio_engine): - assert moriio_engine is not None, ( - "You Cannot pass None engine to MoRIIOWrapper!" - ) - self.moriio_engine = moriio_engine - - def set_backend_type(self, backend_type): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER - post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE - num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS - poll_mode = PollCqMode.POLLING - rdma_cfg = RdmaBackendConfig( - qp_per_transfer, - post_batch_size, - num_worker_threads, - poll_mode, - ) - self.moriio_engine.create_backend(backend_type, rdma_cfg) - - def get_agent_metadata(self): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - engine_metadata = self.moriio_engine.get_engine_desc() - engine_metadata_packed = engine_metadata.pack() - return engine_metadata_packed - - def register_remote_engine(self, remote_packed_engine_metadata): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) - self.moriio_engine.register_remote_engine(consumer_engine_metadata) - return consumer_engine_metadata.key - - def register_local_tensor(self, tensor: torch.Tensor): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - try: - self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor - ) - assert self.local_memory_metadata is not None, ( - "register_torch_tensor returned None" - ) - local_memory_metadata_packed = self.local_memory_metadata.pack() - except Exception as e: - raise MoRIIOError(f"Failed to register local memory: {e}") from e - self.local_memory_registered = True - return local_memory_metadata_packed - - def get_unpack_memory_metadata(self, packed_memory_metadata): - return MemoryDesc.unpack(packed_memory_metadata) - - def build_session(self, local_memory_metadata, 
remote_memory_metadata): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - return self.moriio_engine.create_session( - local_memory_metadata, remote_memory_metadata - ) - - def read_remote_data( - self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): - assert self.local_memory_registered, "You have not register local memory data!" - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - transfer_status = session.batch_read( - local_offset, - remote_offset, - transfer_size_byte, - self.moriio_engine.allocate_transfer_uid(), - ) - - return transfer_status - - def write_remote_data( - self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): - assert self.local_memory_registered, "You have not register local memory data!" - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - write_uid = self.moriio_engine.allocate_transfer_uid() - - transfer_status = session.batch_write( - local_offset, remote_offset, transfer_size_byte, write_uid - ) - with self.lock: - self.transfer_status.append(transfer_status) - - def write_remote_data_single( - self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 - ): - assert self.local_memory_registered, "You have not register local memory data!" - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - transfer_status = self.sessions[sess_idx].write( - local_offset, - remote_offset, - transfer_size_byte, - self.moriio_engine.allocate_transfer_uid(), - ) - with self.lock: - self.transfer_status.append(transfer_status) - - def waiting_for_transfer_complete(self): - if not self.transfer_status: - return - - transfers_to_wait = [] - with self.lock: - transfers_to_wait = self.transfer_status[:] - self.transfer_status.clear() - - for status in transfers_to_wait: - try: - status.Wait() - if not status.Succeeded(): - logger.error( - "Transfer failed: %s, Code: %s", status.Message(), status.Code() - ) - raise TransferError("MoRIIO transfer failed!") - except Exception as e: - logger.error("Transfer %s failed: %s", status, e) - raise - - def async_wait_reqid(self): - assert self.notify_port is not None, "Notify port cannot be None" - - if self.notify_thread is not None: - return - - def _async_wait(): - host = "*" - path = make_zmq_path("tcp", host, self.notify_port) - logger.info("Node starting to listen notify from path = %s", path) - - with zmq_ctx(zmq.ROUTER, path) as sock: - while True: - try: - identity, msg = sock.recv_multipart() - self._handle_message(msg) - except Exception as e: - logger.error("Error processing message: %s", e) - raise HandshakeError(f"Error processing message: {e}") from e - - self.notify_thread = threading.Thread( - target=_async_wait, daemon=True, name="moriio-notify-listener" - ) - self.notify_thread.start() - - def _handle_message(self, msg: bytes): - """Handles incoming messages from remote nodes.""" - # Handles incoming remote messages: - # Prefill Role: - # [write] mode: receives block information (allocation) - # [read] mode: receives block release messages from decode side - # Decode Role: - # [write] mode: receives KV cache write completion notifications - handled = False - try: - data = msgpack.loads(msg) - if isinstance(data, dict) and "req_id" in data: - self._handle_structured_message(data) - - return - except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): - logger.debug("Failed to decode msgpack message, will try as string") - pass - - try: - msg_str = 
msg.decode("UTF-8") - if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX): - self._handle_completion_message(msg_str) - handled = True - except UnicodeDecodeError: - logger.warning("Received non-UTF8 message: %s", msg_str) - if not handled: - raise MoRIIOError(f"Unhandled message format: {msg_str}") - - def _handle_structured_message(self, data: dict): - assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" - req_id = data["req_id"] - block_notify_list = data.get("block_notify_list", []) - decode_dp_rank = data.get("decode_rank", 0) - assert len(block_notify_list) > 0, ( - "block_notify_list cannot be empty in remote allocate message" - ) - - with self.lock: - self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( - block_ids=block_notify_list, decode_dp_rank=decode_dp_rank - ) - - def _handle_completion_message(self, msg: str): - with self.lock: - if get_role() == ROLE.PRODUCER: - self.done_req_ids.append(msg) - else: - self.done_write_cache_req_ids.append(msg) - - def send_notify(self, req_ids, remote_ip, remote_port): - if not remote_ip or not remote_port: - logger.warning("Missing remote_ip or remote_port for notification") - return - - path = make_zmq_path("tcp", remote_ip, remote_port) - - if path not in self.paths: - ctx = zmq.Context.instance() - sock = make_zmq_socket( - ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False - ) - self.paths[path] = sock - - req_list = req_ids if isinstance(req_ids, list) else [req_ids] - - sock = self.paths[path] - try: - for req_id in req_list: - if not isinstance(req_id, str): - logger.warning( - "Invalid req_id type: %s, expected str", type(req_id) - ) - continue - sock.send(req_id.encode("utf-8")) - except Exception as e: - logger.error("Failed to send notification to %s: %s", path, e) - self.paths.pop(path, None) - raise - - def pop_finished_req_ids(self): - # producer invocation: get the set of completed requests at the decode - with self.lock: - done_send = set(self.done_req_ids) - self.done_req_ids = [] - return done_send - - def pop_finished_write_req_ids(self): - # Call the consumer in write mode to get the collection after write completion - with self.lock: - done_write_cache = set(self.done_write_cache_req_ids) - self.done_write_cache_req_ids = [] - return done_write_cache - - def shutdown(self): - logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") - for path, sock in self.paths.items(): - try: - sock.close(linger=0) - logger.debug("Closed ZMQ socket for path: %s", path) - except Exception as e: - logger.warning("Error closing ZMQ socket for path %s: %s", path, e) - self.paths.clear() - - -@dataclass -class ReqMeta: - """Metadata for a single request.""" - - local_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_port: int - remote_handshake_port: int - remote_notify_port: int - remote_engine_id: str - tp_size: int - remote_dp_size: int - - -class MoRIIOConnectorMetadata(KVConnectorMetadata): - def __init__(self): - self.reqs_to_recv: dict[ReqId, ReqMeta] = {} - self.reqs_to_save: dict[ReqId, ReqMeta] = {} - self.reqs_to_send: dict[ReqId, float] = {} - - def __repr__(self): - return_str = "" - for req_id, req_meta in self.reqs_to_recv.items(): - return_str += ( - f"{req_id = },{req_meta.local_block_ids = }," - f"{req_meta.remote_host = },{req_meta.remote_port = }" - f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" - ) - return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," - - for req_id, expiry in self.reqs_to_send.items(): - return_str += 
f"{req_id = },{expiry = }" - return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," - return return_str - - def add_new_req( - self, - request_id: ReqId, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any], - write_mode=False, - ): - _req = ReqMeta( - local_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - remote_handshake_port=kv_transfer_params["remote_handshake_port"], - remote_notify_port=kv_transfer_params["remote_notify_port"], - tp_size=kv_transfer_params.get("tp_size", 1), - remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), - ) - if write_mode: - self.reqs_to_save[request_id] = _req - else: - self.reqs_to_recv[request_id] = _req +def is_moriio_available() -> bool: + return MoRIIO_enabled class MoRIIOConnector(KVConnectorBase_V1): @@ -1371,7 +561,7 @@ class MoRIIOConnectorWorker: """Implementation of Worker side methods""" def __init__(self, vllm_config: VllmConfig, engine_id: str): - if not MoRIIO_enabled: + if not is_moriio_available(): raise RuntimeError( "MoRIIO is not available. Please ensure the 'mori' package " "is installed and properly configured." @@ -2323,21 +1513,3 @@ class MoRIIOConnectorWorker: remote_host, str(remote_notify_port + self.tp_rank), ) - - -@contextlib.contextmanager -def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: - """Context manager for a ZMQ socket""" - - if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): - raise ValueError(f"Unexpected socket type: {socket_type}") - - ctx: zmq.Context | None = None - try: - ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket( - ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER - ) - finally: - if ctx is not None: - ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py new file mode 100644 index 0000000000000..4357c0335ef99 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -0,0 +1,607 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import TYPE_CHECKING, Any, Optional +from weakref import ref as weakref_ref + +import msgpack +import torch +import zmq + +from vllm import envs +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + make_zmq_path, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from queue import Empty, Queue + +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + HandshakeError, + LayerTransferPlan, + MoRIIOAgentMetadata, + MoRIIOConstants, + MoRIIOError, + RemoteAllocInfo, + TransferError, + WriteTask, + get_port_offset, + get_role, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + MoRIIOConnectorWorker, +) + +logger = init_logger(__name__) +try: + from mori.io import ( + EngineDesc, + IOEngine, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + ) + + logger.info("MoRIIO is available") +except ImportError: + logger.error("MoRIIO is not available") + + +"""Write task execution logic for MoRIIO connector.""" + + +class MoRIIOWriter: + """Handles write operations for KV cache transfers. 
+ Implements distributed KV cache transfer using the MoRIIO library + for RDMA-based communication between prefill and decode instances.""" + + def __init__(self, worker: "MoRIIOConnectorWorker"): + """Initialize the writer. + + Args: + worker: Reference to the parent worker + """ + self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) + self._write_task_q: Queue[WriteTask] = Queue() + self._write_worker_started = False + self._write_worker_lock = threading.Lock() + self._deferred_tasks: list[WriteTask] = [] + + @property + def worker(self) -> "MoRIIOConnectorWorker": + """Get the worker instance. + + Returns: + The parent worker instance + + Raises: + RuntimeError: If worker has been garbage collected + """ + worker = self._worker_ref() + if worker is None: + raise RuntimeError("Parent worker has been garbage collected") + return worker + + def ensure_worker_started(self) -> None: + """Ensure the background write worker is running.""" + if self._write_worker_started: + return + self._write_worker_started = True + with self._write_worker_lock: + thread = threading.Thread( + target=self._write_worker_loop, daemon=True, name="moriio-write-worker" + ) + thread.start() + logger.info("Started MoRIIO write worker thread") + + def schedule_write(self, task: WriteTask) -> None: + """Schedule a write task. + + Args: + task: The write task to schedule + """ + self.ensure_worker_started() + self._write_task_q.put(task) + + def _write_worker_loop(self) -> None: + """Main loop for the write worker thread.""" + + while True: + # Process deferred tasks first + self._process_deferred_tasks() + + # Get new task + try: + task = self._write_task_q.get(timeout=0.01) + except Empty: + continue + + # Check if remote blocks are ready + if not self._is_remote_ready(task): + # task.retry_count += 1 + self._deferred_tasks.append(task) + # logger.debug( + # "Deferred task for request %s (retry %d)", + # task.request_id, task.retry_count + # ) + continue + + # Execute the task + + self._execute_write_task(task) + + def _process_deferred_tasks(self) -> None: + """Process tasks that were previously deferred.""" + if not self._deferred_tasks: + return + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + if self._is_remote_ready(task): + self._execute_write_task(task) + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred + + def _is_remote_ready(self, task: WriteTask) -> bool: + """Check if remote blocks are allocated for this task. + + Args: + task: The write task + + Returns: + True if remote blocks are ready + """ + return ( + task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict + ) + + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: + """Get remote allocation info for a request. + + Args: + request_id: The request ID + + Returns: + Remote allocation information + + Raises: + KeyError: If allocation info is missing + """ + try: + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] + except KeyError as e: + raise KeyError( + f"Remote allocation info missing for request {request_id}" + ) from e + + def _execute_write_task(self, task: WriteTask) -> None: + """Execute a single write task. 
+ + Args: + task: The write task to execute + + """ + # Get remote allocation info + request_info = self._get_remote_alloc_info(task.request_id) + + if request_info.block_ids is None: + logger.debug("Request %s remote block IDs not ready", task.request_id) + return + + # Wait for CUDA event + # The attention computation of the current layer cannot + # overlap with the kv transfer task, + # otherwise it will cause precision issues. + # This event is used to synchronize the kv transfer and computation tasks. + task.event.synchronize() + + # Update engine ID with DP rank + task.dst_engine_id = self.worker.get_engine_name_with_dp( + task.dst_engine_id, request_info.decode_dp_rank + ) + + # Get or create sessions + sessions, remote_moriio_meta = self.worker._get_built_session( + task.dst_engine_id + ) + + # Prepare transfer plan + plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) + + # Execute transfer + self._do_layer_write(plan, sessions) + + # Finalize if all layers complete + self._finalize_if_complete(task, request_info) + + def _prepare_transfer_plan( + self, + task: WriteTask, + request_info: RemoteAllocInfo, + remote_moriio_meta: MoRIIOAgentMetadata, + ) -> LayerTransferPlan: + """Prepare the transfer plan for a layer. + + Args: + task: The write task + request_info: Remote allocation information + + Returns: + The transfer plan + """ + # Compute offsets if not cached + if request_info.transfer_offset is None: + offsets = self.worker._compute_block_transfer_offsets( + task.layer_name, + task.local_block_ids, + request_info.block_ids, + remote_moriio_meta, + ) + request_info.transfer_offset = offsets + + # Get session index + layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) + sess_idx = layer_names.index(task.layer_name) + + local_off, remote_off, sizes = request_info.transfer_offset + + return LayerTransferPlan( + request_id=task.request_id, + layer_name=task.layer_name, + sess_idx=sess_idx, + transfer_local_offsets=local_off, + transfer_remote_offsets=remote_off, + transfer_sizes=sizes, + use_batch=True, + ) + + def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: + """Perform the actual layer write. + + Args: + plan: The transfer plan + sessions: List of transfer sessions + """ + if plan.use_batch: + self.worker.moriio_wrapper.write_remote_data( + plan.transfer_sizes, + plan.transfer_local_offsets, + plan.transfer_remote_offsets, + sessions[plan.sess_idx], + ) + else: + for i in range(len(plan.transfer_local_offsets)): + self.worker.moriio_wrapper.write_remote_data_single( + plan.transfer_sizes[i], + plan.transfer_local_offsets[i], + plan.transfer_remote_offsets[i], + plan.sess_idx, + ) + + def _finalize_if_complete( + self, task: WriteTask, request_info: RemoteAllocInfo + ) -> None: + """Finalize transfer if all layers are complete. + + Args: + task: The write task + request_info: Remote allocation information + """ + request_info.writes_done += 1 + + if request_info.writes_done >= self.worker.num_layers: + # Wait for transfer to complete + self.worker.moriio_wrapper.waiting_for_transfer_complete() + + remote_port = task.remote_notify_port + get_port_offset( + request_info.decode_dp_rank, self.worker.tp_rank + ) + # Consider using RDMA immediate data in decode side + # to eliminate the need for this notification. 
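+            #
+            # Worked example of the port arithmetic above (illustrative values;
+            # get_port_offset is called here with its default tp_size=1):
+            #   remote_notify_port=61005, decode_dp_rank=1, tp_rank=2
+            #   -> remote_port = 61005 + 1 * 1 + 2 = 61008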
+ # Consider including the first gen token from prefill in the notification + + # Send completion notification + self.worker.moriio_wrapper.send_notify( + task.request_id, task.remote_ip, remote_port + ) + # mark request as done, then we can free the blocks + with self.worker.moriio_wrapper.lock: + self.worker.moriio_wrapper.done_req_ids.append(task.request_id) + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + task.request_id + ] + logger.debug( + "Completed transfer for request %s, notified port %d", + task.request_id, + remote_port, + ) + + +class MoRIIOWrapper: + """Wrapper for MoRIIO engine operations. + + Handles both producer and consumer roles for KV cache transfers. + + Args: + moriio_engine: MoRIIO engine instance + tp_rank: Tensor parallel rank + dp_rank: Data parallel rank + """ + + def __init__( + self, + moriio_engine: Optional["IOEngine"] = None, + tp_rank: int = 0, + dp_rank: int = 0, + ): + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moriio_engine = moriio_engine + self.remote_memory_metadata = None + self.local_memory_registered = False + self.local_memory_metadata = None + self.transfer_status: list[Any] = [] + self.remote_engine_ip: str | None = None + self.notify_port: int | None = None + self.lock = threading.Lock() + self.done_req_ids: list[str] = [] + self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} + self.done_write_cache_req_ids: list[str] = [] + self.notify_thread: threading.Thread | None = None + self.sessions: list[IOEngine.Session] = [] + self.paths: dict[str, zmq.Socket] = {} + + def set_moriio_engine(self, moriio_engine): + assert moriio_engine is not None, ( + "You Cannot pass None engine to MoRIIOWrapper!" + ) + self.moriio_engine = moriio_engine + + def set_backend_type(self, backend_type): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE + num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS + poll_mode = PollCqMode.POLLING + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_worker_threads, + poll_mode, + ) + self.moriio_engine.create_backend(backend_type, rdma_cfg) + + def get_agent_metadata(self): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + engine_metadata = self.moriio_engine.get_engine_desc() + engine_metadata_packed = engine_metadata.pack() + return engine_metadata_packed + + def register_remote_engine(self, remote_packed_engine_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) + self.moriio_engine.register_remote_engine(consumer_engine_metadata) + return consumer_engine_metadata.key + + def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + try: + self.local_memory_metadata = self.moriio_engine.register_torch_tensor( + tensor + ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) + local_memory_metadata_packed = self.local_memory_metadata.pack() + except Exception as e: + raise MoRIIOError(f"Failed to register local memory: {e}") from e + self.local_memory_registered = True + return local_memory_metadata_packed + + def get_unpack_memory_metadata(self, packed_memory_metadata): + return MemoryDesc.unpack(packed_memory_metadata) + + def build_session(self, local_memory_metadata, 
remote_memory_metadata):
+        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
+        return self.moriio_engine.create_session(
+            local_memory_metadata, remote_memory_metadata
+        )
+
+    def read_remote_data(
+        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
+    ):
+        assert self.local_memory_registered, "Local memory has not been registered!"
+        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
+        transfer_status = session.batch_read(
+            local_offset,
+            remote_offset,
+            transfer_size_byte,
+            self.moriio_engine.allocate_transfer_uid(),
+        )
+
+        return transfer_status
+
+    def write_remote_data(
+        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
+    ):
+        assert self.local_memory_registered, "Local memory has not been registered!"
+        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
+        write_uid = self.moriio_engine.allocate_transfer_uid()
+
+        transfer_status = session.batch_write(
+            local_offset, remote_offset, transfer_size_byte, write_uid
+        )
+        with self.lock:
+            self.transfer_status.append(transfer_status)
+
+    def write_remote_data_single(
+        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
+    ):
+        assert self.local_memory_registered, "Local memory has not been registered!"
+        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
+        transfer_status = self.sessions[sess_idx].write(
+            local_offset,
+            remote_offset,
+            transfer_size_byte,
+            self.moriio_engine.allocate_transfer_uid(),
+        )
+        with self.lock:
+            self.transfer_status.append(transfer_status)
+
+    def waiting_for_transfer_complete(self):
+        if not self.transfer_status:
+            return
+
+        with self.lock:
+            transfers_to_wait = self.transfer_status[:]
+            self.transfer_status.clear()
+
+        for status in transfers_to_wait:
+            try:
+                status.Wait()
+                if not status.Succeeded():
+                    logger.error(
+                        "Transfer failed: %s, Code: %s", status.Message(), status.Code()
+                    )
+                    raise TransferError("MoRIIO transfer failed!")
+            except Exception as e:
+                logger.error("Transfer %s failed: %s", status, e)
+                raise
+
+    def async_wait_reqid(self):
+        assert self.notify_port is not None, "Notify port cannot be None"
+
+        if self.notify_thread is not None:
+            return
+
+        def _async_wait():
+            host = "*"
+            path = make_zmq_path("tcp", host, self.notify_port)
+            logger.info("Listening for notifications on path = %s", path)
+
+            with zmq_ctx(zmq.ROUTER, path) as sock:
+                while True:
+                    try:
+                        identity, msg = sock.recv_multipart()
+                        self._handle_message(msg)
+                    except Exception as e:
+                        logger.error("Error processing message: %s", e)
+                        raise HandshakeError(f"Error processing message: {e}") from e
+
+        self.notify_thread = threading.Thread(
+            target=_async_wait, daemon=True, name="moriio-notify-listener"
+        )
+        self.notify_thread.start()
+
+    def _handle_message(self, msg: bytes):
+        """Handles incoming messages from remote nodes."""
+        # Handles incoming remote messages:
+        # Prefill role:
+        #   [write] mode: receives block information (allocation)
+        #   [read] mode: receives block release messages from the decode side
+        # Decode role:
+        #   [write] mode: receives KV cache write completion notifications
+        handled = False
+        try:
+            data = msgpack.loads(msg)
+            if isinstance(data, dict) and "req_id" in data:
+                self._handle_structured_message(data)
+
+            return
+        except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
+            logger.debug("Failed to decode msgpack message, will try as string")
+
+        try:
+            msg_str = msg.decode("UTF-8")
+            if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
+                self._handle_completion_message(msg_str)
+                handled = True
+        except UnicodeDecodeError:
+            # Log the raw bytes: msg_str is unbound when decoding fails.
+            logger.warning("Received non-UTF8 message: %r", msg)
+        if not handled:
+            raise MoRIIOError(f"Unhandled message format: {msg!r}")
+
+    def _handle_structured_message(self, data: dict):
+        assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
+        req_id = data["req_id"]
+        block_notify_list = data.get("block_notify_list", [])
+        decode_dp_rank = data.get("decode_rank", 0)
+        assert len(block_notify_list) > 0, (
+            "block_notify_list cannot be empty in remote allocate message"
+        )
+
+        with self.lock:
+            self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
+                block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
+            )
+
+    def _handle_completion_message(self, msg: str):
+        with self.lock:
+            if get_role() == ROLE.PRODUCER:
+                self.done_req_ids.append(msg)
+            else:
+                self.done_write_cache_req_ids.append(msg)
+
+    def send_notify(self, req_ids, remote_ip, remote_port):
+        if not remote_ip or not remote_port:
+            logger.warning("Missing remote_ip or remote_port for notification")
+            return
+
+        path = make_zmq_path("tcp", remote_ip, remote_port)
+
+        if path not in self.paths:
+            ctx = zmq.Context.instance()
+            sock = make_zmq_socket(
+                ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
+            )
+            self.paths[path] = sock
+
+        req_list = req_ids if isinstance(req_ids, list) else [req_ids]
+
+        sock = self.paths[path]
+        try:
+            for req_id in req_list:
+                if not isinstance(req_id, str):
+                    logger.warning(
+                        "Invalid req_id type: %s, expected str", type(req_id)
+                    )
+                    continue
+                sock.send(req_id.encode("utf-8"))
+        except Exception as e:
+            logger.error("Failed to send notification to %s: %s", path, e)
+            self.paths.pop(path, None)
+            raise
+
+    def pop_finished_req_ids(self):
+        # Producer side: pop the set of requests completed on the decode side.
+        with self.lock:
+            done_send = set(self.done_req_ids)
+            self.done_req_ids = []
+        return done_send
+
+    def pop_finished_write_req_ids(self):
+        # Consumer side (write mode): pop the set of requests whose KV cache
+        # writes have completed.
+        with self.lock:
+            done_write_cache = set(self.done_write_cache_req_ids)
+            self.done_write_cache_req_ids = []
+        return done_write_cache
+
+    def shutdown(self):
+        logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
+        for path, sock in self.paths.items():
+            try:
+                sock.close(linger=0)
+                logger.debug("Closed ZMQ socket for path: %s", path)
+            except Exception as e:
+                logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
+        self.paths.clear()
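
---

Usage sketch (illustrative, not part of the diff): after this split, the connector is
selected by its registered name at the new module path, and MoRIIOConfig.from_vllm_config
reads the kv_connector_extra_config keys shown below. The concrete host/port values are
placeholder assumptions, not defaults introduced by this patch.

    # Hedged example: only the key names are load-bearing; endpoints are made up.
    from vllm.config import KVTransferConfig

    kv_transfer_config = KVTransferConfig(
        kv_connector="MoRIIOConnector",
        kv_role="kv_producer",  # prefill instance; the decode side uses "kv_consumer"
        kv_connector_extra_config={
            "proxy_ip": "10.0.0.1",     # proxy host for heartbeats (assumed)
            "proxy_ping_port": "7500",  # proxy heartbeat ingress port (assumed)
            "http_port": "8000",        # this instance's HTTP endpoint (assumed)
            "handshake_port": "6301",   # mori engine handshake port
            "notify_port": "61005",     # base notify port, offset per dp/tp rank
        },
    )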