From 9a15ae9f723b9bfc006e4da1b4e2c7eb8c944b74 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 05:12:07 +0000 Subject: [PATCH 01/62] initial commit Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/factory.py | 5 + .../kv_connector/v1/moriio_connector.py | 2175 +++++++++++++++++ 2 files changed, 2180 insertions(+) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index df871dd7cbe4f..42821d341d6b6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -179,6 +179,11 @@ KVConnectorFactory.register_connector( "MultiConnector", ) +KVConnectorFactory.register_connector( + "MoRIIOConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector", + "MoRIIOConnector") + KVConnectorFactory.register_connector( "OffloadingConnector", "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py new file mode 100644 index 0000000000000..4e4daebd3ab77 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -0,0 +1,2175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import math +import os +import pickle +import queue +import threading +import time +from collections import defaultdict +from collections.abc import Iterator +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import msgpack +import msgspec +import numpy as np +import torch +import zmq + +from vllm import envs +from vllm.attention.selector import backend_name_to_enum, get_attn_backend 
+from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group, get_world_group) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.platforms import _Backend +from vllm.utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus +from weakref import ref as weakref_ref + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +from dataclasses import field +from queue import Empty, Queue +from enum import Enum +import logging + + +logger = init_logger(__name__) + +Transfer = tuple[int, float] +EngineId = str +ReqId = str + + +class MoRIIOConstants: + """Constants for MoRIIO connector.""" + + # ZMQ message types + GET_META_MSG = b"get_meta_msg" + POP_DONE_RECV = b"pop_done_recv" + OVER = b"OVER" + COMPLETION_PREFIX = "cmpl" + + + # Default GPU count per node for standard configurations + RANK_PER_NODE = 8 + + + +try: + import mori + from mori.io import (BackendType, EngineDesc, IOEngine, IOEngineConfig, + MemoryDesc, StatusCode) + logger.info("MoRIIO is available") + MoRIIO_enabled = True +except ImportError: + logger.error("MoRIIO is not available") + MoRIIO_enabled = False + +@dataclass +class WriteTask: + request_id: str + dst_engine_id: str + local_block_ids: list[int] + remote_block_ids_hint: Optional[list[int]] + layer_name: str + event: torch.cuda.Event + remote_notify_port: int + remote_ip: str + enqueue_time: float = field(default_factory=time.perf_counter) + retried: int = 0 + +@dataclass +class LayerTransferPlan: + """Plan for transferring a single layer.""" + request_id: 
str + layer_name: str + sess_idx: int + transfer_local_offsets: list[int] + transfer_remote_offsets: list[int] + transfer_sizes: list[int] + use_batch: bool = True + +@dataclass +class RemoteAllocInfo: + """Information about remote block allocation.""" + block_ids: list[int] + writes_done: int = 0 + decode_dp_rank: int = 0 + transfer_offset: tuple[list[int], list[int], list[int]] | None = None + + +class ROLE(Enum): + PRODUCER = "producer" + CONSUMER = "consumer" + NOTINIT = "notinit" + +class RoleManager: + """Manages role state across the connector.""" + + _instance: Optional["RoleManager"] = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._role: ROLE = ROLE.NOTINIT + + @classmethod + def get_instance(cls) -> "RoleManager": + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def set_role(self, role: ROLE) -> None: + """Set the current role.""" + with self._lock: + self._role = role + + def get_role(self) -> ROLE: + """Get the current role.""" + return self._role + + +def set_role(role: ROLE): + """Set the global role.""" + RoleManager.get_instance().set_role(role) + +def get_role() -> ROLE: + """Get the global role.""" + return RoleManager.get_instance().get_role() + + +class MoRIIOMode(Enum): + READ = "read" + WRITE = "write" + +class MoRIIOError(Exception): + """Base exception for MoRIIO operations.""" + pass + +class HandshakeError(MoRIIOError): + """Exception raised when handshake fails.""" + pass + +class TransferError(MoRIIOError): + """Exception raised when transfer fails.""" + pass + + +def get_moriio_mode() -> MoRIIOMode: + read_mode = os.environ.get('MORIIO_CONNECTOR_READ_MODE', 'false').lower() + # logger.info(f"MoRIIO Connector Read Mode = {read_mode}") + if read_mode in ('true', '1', 'yes', 'on'): + return MoRIIOMode.READ + else: + return MoRIIOMode.WRITE + + +def get_port_offset(dp_rank: int,tp_rank: int, tp_size:int=1) -> int: + return 
((dp_rank)*tp_size+tp_rank )%MoRIIOConstants.RANK_PER_NODE + +@dataclass +class MoRIIOConfig: + local_ip: str + local_kv_port: int + proxy_ip: str + proxy_port: int + local_ping_port: int + proxy_ping_port: int + http_port: int + handshake_port: int + notify_port: int + tp_rank: int + dp_rank: int + dp_size: int + tp_size: int + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": + + + # Port Configuration: + # local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it) + # proxy_ping_port -> Remote proxy's heartbeat ingress port + # http_port -> Instance's HTTP service endpoint + # local_kv_port -> KV service port for Mori engine + # notify_port -> For synchronizing stages between nodes + + kv_transfer_config = vllm_config.kv_transfer_config + extra_config = kv_transfer_config.kv_connector_extra_config + tp_rank = get_tensor_model_parallel_rank() + dp_rank = vllm_config.parallel_config.data_parallel_rank + base_kv_port = int(kv_transfer_config.kv_port) + base_ping_port = int(extra_config["local_ping_port"]) + base_notify_port = int(extra_config["notify_port"]) + dp_size=vllm_config.parallel_config.data_parallel_size + tp_size=get_tensor_model_parallel_world_size() + port_offset=get_port_offset(dp_rank,tp_rank) + + return cls( + local_ip=get_ip(), + local_kv_port=base_kv_port + port_offset, + proxy_ip=extra_config["proxy_ip"], + proxy_port=int(extra_config["proxy_port"]), + local_ping_port=base_ping_port+port_offset, + proxy_ping_port=int(extra_config["proxy_ping_port"]), + http_port=int(extra_config['http_port']), + handshake_port=int(extra_config['handshake_port']), + notify_port=base_notify_port + port_offset, + tp_rank=tp_rank, + dp_rank=dp_rank, + dp_size=dp_size, + tp_size=tp_size, + ) + + +"""Write task execution logic for MoRIIO connector.""" + +class MoRIIOWriter: + """Handles write operations for KV cache transfers. 
Implements distributed KV cache transfer using the MoRIIO library + for RDMA-based communication between prefill and decode instances.""" + + def __init__(self, worker: "MoRIIOConnectorWorker"): + """Initialize the writer. + + Args: + worker: Reference to the parent worker + """ + # self.worker = worker + self._worker_ref: "weakref_ref[MoRIIOConnectorWorker]" = weakref_ref(worker) + self._write_task_q: Queue[WriteTask] = Queue() + self._write_worker_started = False + self._write_worker_lock = threading.Lock() + self._deferred_tasks: list[WriteTask] = [] + + @property + def worker(self) -> "MoRIIOConnectorWorker": + """Get the worker instance. + + Returns: + The parent worker instance + + Raises: + RuntimeError: If worker has been garbage collected + """ + worker = self._worker_ref() + if worker is None: + raise RuntimeError("Parent worker has been garbage collected") + return worker + + def ensure_worker_started(self) -> None: + """Ensure the background write worker is running.""" + if self._write_worker_started: + return + self._write_worker_started = True + with self._write_worker_lock: + + thread = threading.Thread( + target=self._write_worker_loop, + daemon=True, + name="moriio-write-worker" + ) + thread.start() + logger.info("Started MoRIIO write worker thread") + + def schedule_write(self, task: WriteTask) -> None: + """Schedule a write task. 
+ + Args: + task: The write task to schedule + """ + self.ensure_worker_started() + self._write_task_q.put(task) + + def _write_worker_loop(self) -> None: + """Main loop for the write worker thread.""" + logger.info("Write worker loop started") + + while True: + # Process deferred tasks first + self._process_deferred_tasks() + + # Get new task + try: + task = self._write_task_q.get( + timeout=0.01 + ) + except Empty: + continue + + # Check if remote blocks are ready + if not self._is_remote_ready(task): + # task.retry_count += 1 + self._deferred_tasks.append(task) + # logger.debug( + # "Deferred task for request %s (retry %d)", + # task.request_id, task.retry_count + # ) + continue + + # Execute the task + + self._execute_write_task(task) + + + def _process_deferred_tasks(self) -> None: + """Process tasks that were previously deferred.""" + if not self._deferred_tasks: + return + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + if self._is_remote_ready(task): + + self._execute_write_task(task) + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred + + def _is_remote_ready(self, task: WriteTask) -> bool: + """Check if remote blocks are allocated for this task. + + Args: + task: The write task + + Returns: + True if remote blocks are ready + """ + return (task.request_id in + self.worker.moriio_wrapper.done_remote_allocate_req_dict) + + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: + """Get remote allocation info for a request. + + Args: + request_id: The request ID + + Returns: + Remote allocation information + + Raises: + KeyError: If allocation info is missing + """ + try: + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + request_id + ] + except KeyError as e: + raise KeyError( + f"Remote allocation info missing for request {request_id}" + ) from e + + def _execute_write_task(self, task: WriteTask) -> None: + """Execute a single write task. 
+ + Args: + task: The write task to execute + + """ + # Get remote allocation info + request_info = self._get_remote_alloc_info(task.request_id) + + if request_info.block_ids is None: + logger.debug( + "Request %s remote block IDs not ready", + task.request_id + ) + return + + # Wait for CUDA event + task.event.synchronize() + + # Update engine ID with DP rank + task.dst_engine_id = ( + f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" + ) + + # Get or create sessions + sessions = self.worker._get_built_session(task.dst_engine_id) + + # Prepare transfer plan + plan = self._prepare_transfer_plan(task, request_info) + + # Execute transfer + self._do_layer_write(plan, sessions) + + # Finalize if all layers complete + self._finalize_if_complete(task, request_info) + + def _prepare_transfer_plan( + self, + task: WriteTask, + request_info: RemoteAllocInfo + ) -> LayerTransferPlan: + """Prepare the transfer plan for a layer. + + Args: + task: The write task + request_info: Remote allocation information + + Returns: + The transfer plan + """ + # Compute offsets if not cached + if request_info.transfer_offset is None: + offsets = self.worker._compute_block_transfer_offsets( + task.layer_name, + task.local_block_ids, + request_info.block_ids + ) + request_info.transfer_offset = offsets + + # Get session index + layer_names = list( + self.worker.layer_name_to_local_kv_cache_metadata.keys() + ) + sess_idx = layer_names.index(task.layer_name) + + local_off, remote_off, sizes = request_info.transfer_offset + + return LayerTransferPlan( + request_id=task.request_id, + layer_name=task.layer_name, + sess_idx=sess_idx, + transfer_local_offsets=local_off, + transfer_remote_offsets=remote_off, + transfer_sizes=sizes, + use_batch=True + ) + + def _do_layer_write( + self, + plan: LayerTransferPlan, + sessions: list + ) -> None: + """Perform the actual layer write. 
+ + Args: + plan: The transfer plan + sessions: List of transfer sessions + """ + if plan.use_batch: + self.worker.moriio_wrapper.write_remote_data( + plan.transfer_sizes, + plan.transfer_local_offsets, + plan.transfer_remote_offsets, + sessions[plan.sess_idx] + ) + else: + for i in range(len(plan.transfer_local_offsets)): + self.worker.moriio_wrapper.write_remote_data_single( + plan.transfer_sizes[i], + plan.transfer_local_offsets[i], + plan.transfer_remote_offsets[i], + plan.sess_idx + ) + + def _finalize_if_complete( + self, + task: WriteTask, + request_info: RemoteAllocInfo + ) -> None: + """Finalize transfer if all layers are complete. + + Args: + task: The write task + request_info: Remote allocation information + """ + request_info.writes_done += 1 + + if request_info.writes_done >= self.worker.num_layers: + # Wait for transfer to complete + self.worker.moriio_wrapper.waiting_for_transfer_complete() + + + remote_port = task.remote_notify_port + get_port_offset( + request_info.decode_dp_rank, + self.worker.tp_rank + ) + # TODO: + # Consider using RDMA immediate data in decode side to eliminate the need for this notification. + # Consider including the first gen token from prefill in the notification + + # Send completion notification + self.worker.moriio_wrapper.send_notify( + task.request_id, + task.remote_ip, + remote_port + ) + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[task.request_id] + logger.debug( + "Completed transfer for request %s, notified port %d", + task.request_id, remote_port + ) +class MoRIIOWrapper: + """Wrapper for MoRIIO engine operations. + + Handles both producer and consumer roles for KV cache transfers. 
+ + Args: + moriio_engine: MoRIIO engine instance + tp_rank: Tensor parallel rank + dp_rank: Data parallel rank + """ + + def __init__(self, moriio_engine=None,tp_rank=0,dp_rank=0): + self.tp_rank=tp_rank + self.dp_rank=dp_rank + self.moriio_engine = moriio_engine + self.remote_memory_metadata = None + self.local_memory_registered = False + self.local_memory_metadata = None + self.transfer_status = [] + self.remote_engine_ip = None + self.notify_port = None + self.notify_sock = None + self.lock = threading.Lock() + self.done_req_ids = [] + self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} + self.done_write_cache_req_ids = [] + self.notify_thread = None + self.sock = None + self.sessions = [] + self.kv_caches = None + self.paths = {} + + + def set_moriio_engine(self, moriio_engine): + assert moriio_engine is not None, "You Cannot pass None engine to MoRIIOWrapper!" + self.moriio_engine = moriio_engine + + def set_backend_type(self, backend_type): + self.moriio_engine.create_backend(backend_type) + + def get_agent_metadata(self): + engine_metadata = self.moriio_engine.get_engine_desc() + engine_metadata_packed = engine_metadata.pack() + return engine_metadata_packed + + def register_remote_engine(self, remote_packed_engine_metadata): + consumer_engine_metadata = EngineDesc.unpack( + remote_packed_engine_metadata) + self.moriio_engine.register_remote_engine(consumer_engine_metadata) + return consumer_engine_metadata.key + + def register_local_tensor(self, tensor: torch.Tensor): + try: + self.local_memory_metadata = self.moriio_engine.register_torch_tensor( + tensor) + local_memory_metadata_packed = self.local_memory_metadata.pack() + except Exception as e: + raise MoRIIOError(f"Failed to register local memory: {e}") from e + self.local_memory_registered = True + return local_memory_metadata_packed + + def get_unpack_memory_metadata(self, packed_memory_metadata): + return MemoryDesc.unpack(packed_memory_metadata) + + def build_session(self, 
local_memory_metadata, remote_memory_metadata): + return self.moriio_engine.create_session(local_memory_metadata, + remote_memory_metadata) + + def read_remote_data(self, + transfer_size_byte, + local_offset=0, + remote_offset=0, + session=None): + assert self.local_memory_registered, "You have not register local memory data!" + + transfer_status = session.batch_read( + local_offset, remote_offset, transfer_size_byte, + self.moriio_engine.allocate_transfer_uid()) + + return transfer_status + + def write_remote_data(self, + transfer_size_byte, + local_offset=0, + remote_offset=0, + session=None): + assert self.local_memory_registered, "You have not register local memory data!" + write_uid = self.moriio_engine.allocate_transfer_uid() + + transfer_status = session.batch_write(local_offset, remote_offset, + transfer_size_byte, write_uid) + with self.lock: + self.transfer_status.append(transfer_status) + + def write_remote_data_single(self, + transfer_size_byte, + local_offset=0, + remote_offset=0, + sess_idx=0): + assert self.local_memory_registered, "You have not register local memory data!" 
+ + transfer_status = self.sessions[sess_idx].write( + local_offset, remote_offset, transfer_size_byte, + self.moriio_engine.allocate_transfer_uid()) + with self.lock: + self.transfer_status.append(transfer_status) + + def waiting_for_transfer_complete(self): + if not self.transfer_status: + return + + transfers_to_wait = [] + with self.lock: + transfers_to_wait = self.transfer_status[:] + self.transfer_status.clear() + + for status in transfers_to_wait: + try: + status.Wait() + if not status.Succeeded(): + logger.error( + f"Transfer failed: {status.Message()}, Code: {status.Code()}" + ) + raise TransferError(f"MoRIIO transfer failed!") + except Exception as e: + logger.error(f"Transfer {status} failed: {e}") + raise + + def async_wait_reqid(self): + + + assert self.notify_port is not None, "Notify port cannot be None" + + if self.notify_thread is not None: + return + + def _async_wait(): + host = "*" + path = make_zmq_path("tcp", host, self.notify_port) + logger.info(f"Node starting to listen notify from path = {path}") + + with zmq_ctx(zmq.ROUTER, path) as sock: + while True: + try: + identity, msg = sock.recv_multipart() + self._handle_message(msg) + except Exception as e: + logger.error(f"Error processing message: {e}") + raise HandshakeError(f"Error processing message: {e}") from e + continue + + self.notify_thread = threading.Thread(target=_async_wait, + daemon=True, + name="moriio-notify-listener") + self.notify_thread.start() + + def _handle_message(self, msg: bytes): + """Handles incoming messages from remote nodes.""" + # Handles incoming remote messages: + # Prefill Role: + # [write] mode: receives block information (allocation) + # [read] mode: receives block release messages from decode side + # Decode Role: + # [write] mode: receives KV cache write completion notifications + handled = False + try: + data = msgpack.loads(msg) + if isinstance(data, dict) and "req_id" in data: + self._handle_structured_message(data) + handled = True + + return + except 
(msgpack.exceptions.ExtraData, + msgpack.exceptions.UnpackException): + pass + + try: + msg_str = msg.decode("UTF-8") + if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX): + self._handle_completion_message(msg_str) + handled = True + except UnicodeDecodeError: + logger.warning(f"Received non-UTF8 message: {msg}") + if not handled: + raise MoRIIOError(f"Unhandled message format: {msg}") + + def _handle_structured_message(self, data: dict): + req_id = data["req_id"] + block_notify_list = data.get("block_notify_list", []) + decode_dp_rank=data.get("decode_rank",0) + assert len(block_notify_list) > 0, "block_notify_list cannot be empty in remote allocate message" + + with self.lock: + self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(block_ids=block_notify_list,decode_dp_rank=decode_dp_rank) + + def _handle_completion_message(self, msg: str): + with self.lock: + if get_role() == ROLE.PRODUCER: + self.done_req_ids.append(msg) + else: + self.done_write_cache_req_ids.append(msg) + + def send_notify(self, req_ids, remote_ip=None, remote_port=None): + if not remote_ip or not remote_port: + logger.warning("Missing remote_ip or remote_port for notification") + return + + path = make_zmq_path("tcp", remote_ip, str(remote_port)) + + if path not in self.paths: + ctx = zmq.Context() + sock = make_zmq_socket(ctx=ctx, + path=path, + socket_type=zmq.DEALER, + bind=False) + self.paths[path] = sock + + req_list = req_ids if isinstance(req_ids, list) else [req_ids] + + sock = self.paths[path] + try: + for req_id in req_list: + if not isinstance(req_id, str): + logger.warning( + f"Invalid req_id type: {type(req_id)}, expected str") + continue + sock.send(req_id.encode("utf-8")) + except Exception as e: + logger.error(f"Failed to send notification to {path}: {e}") + self.paths.pop(path, None) + raise + + def pop_finished_req_ids(self): + # producer invocation: get the set of completed requests at the decode + with self.lock: + done_send = set(self.done_req_ids) + 
self.done_req_ids = [] + return done_send + + def pop_finished_write_req_ids(self): + # Call the consumer in write mode to get the collection after write completion + with self.lock: + done_write_cache = set(self.done_write_cache_req_ids) + self.done_write_cache_req_ids = [] + return done_write_cache + + + +class MoRIIOAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + + +@dataclass +class ReqMeta: + """Metadata for a single request.""" + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_handshake_port: int + remote_notify_port: int + remote_engine_id: str + tp_size: int + remote_dp_size: int + + +class MoRIIOConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_save: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: dict[ReqId, float] = {} + + def __repr__(self): + return_str = "" + for req_id, req_meta in self.reqs_to_recv.items(): + return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }" + return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," + + for req_id, req_meta in self.reqs_to_send.items(): + return_str += f"{req_id = },{req_meta = }" + return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," + return return_str + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + write_mode=False, + ): + + _req = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + 
remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_handshake_port=kv_transfer_params['remote_handshake_port'], + remote_notify_port=kv_transfer_params.get('remote_notify_port'), + tp_size=kv_transfer_params.get("tp_size", 1), + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1) + ) + if write_mode: + self.reqs_to_save[request_id] = _req + else: + self.reqs_to_recv[request_id] = _req + + +class MoRIIOConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + # assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id = str( + get_ip()) + ":" + str(vllm_config.kv_transfer_config. + kv_connector_extra_config['handshake_port']) + self.mode = get_moriio_mode() + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[MoRIIOConnectorScheduler] = \ + MoRIIOConnectorScheduler(vllm_config, self.engine_id) + self.connector_worker: Optional[MoRIIOConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MoRIIOConnectorWorker( + vllm_config, self.engine_id) + logger.info( + f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}" + ) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens, 
self.connector_worker) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + + if self.mode == MoRIIOMode.WRITE: + if get_role() == ROLE.CONSUMER: + self.connector_worker.moriio_wrapper.async_wait_reqid() + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + + # Only producer/prefill saves KV Cache + if get_role() == ROLE.CONSUMER: + return + self.connector_worker.save_kv_layer(self._connector_metadata, + layer_name, kv_layer, + attn_metadata, **kwargs) + + return None + + def wait_for_save(self): + pass + + +class MoRIIOConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: 
VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id: EngineId = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.mode=get_moriio_mode() + + + + self.handeshake_port=self.vllm_config.kv_transfer_config.kv_connector_extra_config['handshake_port'] + logger.info( + f"==========> Initializing MoRIIO Scheduler {engine_id = }" + ) + + self.side_notify_port = self.vllm_config.kv_transfer_config.kv_connector_extra_config[ + 'notify_port'] + self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size + self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank + self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} + + # For chunked prefill, we perform layer-wise access within the final chunk. + # TODO: Perform access at the end of each chunk. + self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {} + + + if self.is_producer: + set_role(ROLE.PRODUCER) + else: + set_role(ROLE.CONSUMER) + # Reqs to send and their expiration time + self._reqs_need_send: dict[ReqId, float] = {} + self.sock = None + self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" + self.paths = {} + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. 
+ num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + if self.is_producer: + return 0, False + + + if self.mode == MoRIIOMode.WRITE: + # MoriiO in write mode, no remote prefill + + return len(request.prompt_token_ids) - num_computed_tokens, True + + return len(request.prompt_token_ids) - 1 - num_computed_tokens, False + + def send_notify_block(self, + req_id: str, + block_notify_list: list[int] = None, + host=None, + port=None): + + path = make_zmq_path("tcp", host, port) + if path not in self.paths: + ctx = zmq.Context() + sock = make_zmq_socket(ctx=ctx, + path=path, + socket_type=zmq.DEALER, + bind=False) + self.paths[path] = sock + + data = { + "req_id": req_id, + "block_notify_list": block_notify_list or [], + "decode_rank": self.dp_rank, + "type": "remote_blocks" + } + # logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }") + serialized_data = msgpack.dumps(data) + self.paths[path].send(serialized_data) + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + connector_worker: Optional["MoRIIOConnectorWorker"] = None): + + params = request.kv_transfer_params + if params.get("do_remote_decode"): + local_block_ids = blocks.get_block_ids()[0] + self._reqs_need_save[request.request_id] = (request, + local_block_ids) + + if params is not None and params.get("do_remote_prefill"): + if self.mode == MoRIIOMode.READ: + if remote_block_ids := params.get("remote_block_ids"): + if all(p in params + for p in ("remote_engine_id", "remote_host", + "remote_port")): + # If remote_blocks and num_external_tokens = 0, we + # a full prefix cache hit on the D worker. 
We need to call + # send_notif in _read_blocks to free the memory on the P. + + # Get unhashed blocks to pull from remote. + local_block_ids = blocks.get_block_ids()[0] + assert len(local_block_ids) <= len(remote_block_ids) + if len(local_block_ids) == len(remote_block_ids): + pass + else: + local_block_ids = remote_block_ids[ + -len(local_block_ids):] + + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + + else: + remote_dp_rank = request.kv_transfer_params.get('remote_dp_rank', 0) + + for tp_index in range(self.tp_size): + + target_port = request.kv_transfer_params[ + 'remote_notify_port'] + get_port_offset(remote_dp_rank, tp_index) + + + self.send_notify_block(req_id=request.request_id, + block_notify_list=blocks.get_block_ids()[0], + host=params.get("remote_host"), + port=target_port) + + # Only trigger 1 KV transfer per request. + + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MoRIIOConnectorMetadata() + + if self.mode == MoRIIOMode.WRITE: + # when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs + + if get_role()== ROLE.CONSUMER: + for new_req in scheduler_output.scheduled_new_reqs: + red_id = new_req.req_id + local_block_ids = list(new_req.block_ids) + kv_transfer_params = new_req.sampling_params.extra_args[ + 'kv_transfer_params'] + meta.add_new_req( + red_id, + local_block_ids, + kv_transfer_params, + ) + if get_role()== ROLE.PRODUCER: + # This is the logic for checking against chunked prefill. + # When the last chunk is identified, it places the request metadata into the saving queue. 
+ + for i,req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids): + new_block_ids = scheduler_output.scheduled_cached_reqs.new_block_ids[i] + + if new_block_ids is not None: + block_ids = new_block_ids[0] + + req, existing_blocks = self._reqs_need_pending_save[req_id] + updated_blocks = list(existing_blocks) + ([block_ids] if isinstance(block_ids, int) else block_ids) + self._reqs_need_pending_save[req_id] = (req, updated_blocks) + if len(self._reqs_need_pending_save[req_id][1]*self.block_size)>=req.num_prompt_tokens: + + meta.add_new_req( + request_id=req_id, + local_block_ids=self._reqs_need_pending_save[req_id][1], + kv_transfer_params=req.kv_transfer_params, + write_mode=True, + ) + del self._reqs_need_pending_save[req_id] + + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + for req_id, (req, block_ids) in self._reqs_need_save.items(): + assert req.kv_transfer_params is not None + if req.num_prompt_tokens>len(block_ids): + # not last chunk prefill + self._reqs_need_pending_save[req_id] = (req, block_ids) + continue + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + write_mode=True, + ) + # Clear the list once workers start the transfers + + meta.reqs_to_send = self._reqs_need_send + + self._reqs_need_recv.clear() + self._reqs_need_save.clear() + self._reqs_need_send = {} + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. 
+ """ + + params = request.kv_transfer_params + logger.debug( + "MoriioConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if (not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # computed_block_ids = block_ids if all_full else block_ids[:-1] + computed_block_ids = block_ids + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + if delay_free_blocks: + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + + # If we execute in P-D serial mode, no notification port is needed. + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.handeshake_port, + tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + + +class MoRIIOConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if not MoRIIO_enabled: + raise RuntimeError( + "MoRIIO is not available. Please ensure the 'mori' package " + "is installed and properly configured." 
+ ) + + self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config) + self.mode=get_moriio_mode() + + logger.info("Initializing MoRIIO worker %s", engine_id) + # for debug + logging.getLogger("aiter").disabled = True + + # Config. + self.vllm_config = vllm_config + self.kv_transfer_config = vllm_config.kv_transfer_config + self.is_producer = self.kv_transfer_config.is_kv_producer + + if self.is_producer: + set_role(ROLE.PRODUCER) + else: + set_role(ROLE.CONSUMER) + # mori engine + self._rank = get_world_group().rank + self._local_rank = get_world_group().local_rank + self.tp_rank = self.moriio_config.tp_rank + self.dp_rank= self.moriio_config.dp_rank + + logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" + f",{self.is_producer= }") + + + self.local_ip = self.moriio_config.local_ip + self.local_kv_port=self.moriio_config.local_kv_port + self.proxy_ip = self.moriio_config.proxy_ip + self.proxy_port = self.moriio_config.proxy_port + self.local_ping_port = self.moriio_config.local_ping_port + self.proxy_ping_port =self.moriio_config.proxy_ping_port + self.http_port = self.moriio_config.http_port + self.handshake_port = self.moriio_config.handshake_port + self.notify_port = self.moriio_config.notify_port + + self.zmq_context = zmq.Context() + self.metadata_address = f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" + self.request_address = f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + + self.moriio_engine = None + self._handle_request_thread = None + self._ping_thread = None + self._writer = MoRIIOWriter(self) + + engine_suffix = (f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" + f":tp {self.tp_rank}:dp {self.dp_rank}") + if not self.is_producer: + self.poller = zmq.Poller() + self.metadata_socket = self.zmq_context.socket(zmq.ROUTER) + self.metadata_socket.bind(f"tcp://{self.metadata_address}") + self.poller.register(self.metadata_socket, zmq.POLLIN) + + + self.moriio_engine = IOEngine( 
+ "consumer:" + engine_suffix, + IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) + + self._handle_request_thread = threading.Thread( + target=self.handle_proxy_request, daemon=True) + self._handle_request_thread.start() + else: + + self.moriio_engine = IOEngine( + "producer:" + engine_suffix, + IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) + + logger.info("build IOEngine %s:%s", self.moriio_config.local_ip, self.moriio_config.local_kv_port) + + if self._rank == 0 and self.moriio_config.proxy_ip: + self._ping_thread = threading.Thread(target=self._ping, + args=(self.zmq_context, ), + daemon=True) + self._ping_thread.start() + + logger.info( + f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}" + ) + logger.debug( + f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.proxy_port = },{self.local_ping_port = },{self.proxy_ping_port = }" + ) + # Agent. + self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank,dp_rank=self.dp_rank) + + self.moriio_wrapper.set_moriio_engine(self.moriio_engine) + + self.moriio_wrapper.set_backend_type(BackendType.RDMA) + + self.moriio_wrapper.notify_port = self.moriio_config.notify_port + + + self.local_kv_cache_metadata = [] + self.local_kv_cache_size = [] + self.layer_name_to_local_kv_cache_metadata: dict[str, + List[Any]] = dict() + + self.remote_kv_cache_metadata = [] + self.remote_kv_cache_size = [] + self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[ + str, List[Any]]] = dict() + self.slot_size_bytes = 0 + + self.load_ready_flag = False + self.write_ready_flags = {} + self.kv_cache_shape = None + self.block_shape = None + self.kv_element_size = 0 + + self.done_sending_reqs = [] + self.done_send_threads = [] + + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
+ self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + + self.side_channel_port: int = ( + self.moriio_config.handshake_port + + get_port_offset(self.dp_rank,self.tp_rank) + ) + logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") + logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}, han") + self.engine_id: EngineId = engine_id + + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and moriio tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + + # Number of MoRIIO regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + + + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[EngineId, int] = {} + # In progress transfers. + self._recving_transfers:defaultdict[ReqId, list]=defaultdict(list) + self._recving_transfers_callback_addr: dict[ReqId, tuple[str,str]]= {} + + # Track the expiration time of requests that are waiting to be sent. + self._reqs_to_send: dict[ReqId, float] = {} + + # Background thread for handling new handshake requests. + self._moriio_handshake_listener_t: Optional[threading.Thread] = None + # Background thread for initializing new MoRIIO handshakes. + self._handshake_initiation_executor = ThreadPoolExecutor( + # MoRIIO is not guaranteed to be thread-safe, limit 1 worker. + max_workers=1, + thread_name_prefix="vllm-moriio-handshake-initiator") + self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() + self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + # Protects _handshake_futures and _remote_agents. 
+ self._handshake_lock = threading.RLock() + + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[Optional[int]] = [] + self.use_mla = self.model_config.use_mla + self.built_session = False + self.built_write_session: defaultdict[str, list] = defaultdict(list) + self._write_session_lock = threading.Lock() + self.debug_cache = [] + backend = get_attn_backend(self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla) + self.backend_name = backend.get_name() + attn_backend = backend_name_to_enum(self.backend_name) + self._use_flashinfer = attn_backend == _Backend.FLASHINFER + logger.debug("Detected attention backend %s", self.backend_name) + + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + + + ####write worker### + # self._write_task_q: Queue[WriteTask] = Queue() + # self._write_worker_started = False + # self._write_worker_lock = threading.Lock() + # self._deferred_tasks: list[WriteTask] = [] + # ####write worker### + + + + def schedule_write_blocks( + self, + request_id: str, + dst_engine_id: str, + local_block_ids: list[int], + remote_block_ids: Optional[list[int]], + layer_name: str, + kv_layer: torch.Tensor, + remote_notify_port: int, + remote_ip: str + ) -> None: + """Schedule a block write operation. 

        Args:
            request_id: Unique identifier for the request
            dst_engine_id: Destination engine ID
            local_block_ids: Local block IDs to transfer
            remote_block_ids: Hint for remote block IDs
            layer_name: Name of the layer
            kv_layer: KV cache tensor
            remote_notify_port: Port for completion notification
            remote_ip: IP address of remote node
        """
        # Record a CUDA event on the current stream so the writer thread can
        # wait for the KV computation to finish before issuing RDMA writes.
        stream = torch.cuda.current_stream()
        event = torch.cuda.Event()
        event.record(stream)

        task = WriteTask(request_id=request_id,
                         dst_engine_id=dst_engine_id,
                         local_block_ids=local_block_ids,
                         remote_block_ids_hint=remote_block_ids,
                         layer_name=layer_name,
                         event=event,
                         remote_notify_port=remote_notify_port,
                         remote_ip=remote_ip)
        self._writer.schedule_write(task)

    def _get_built_session(self, remote_engine_id):
        """Return (building and caching on first use) one transfer session
        per layer for the given remote engine."""
        if remote_engine_id not in self.built_write_session:
            cur_remote_engine_sessions = []
            for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(
            ):
                unpcaked_local_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata(
                    local_meta[0])
                unpcaked_remote_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata(
                    self.layer_name_to_remote_kv_cache_metadata[
                        remote_engine_id][ln][0])
                cur_remote_engine_sessions.append(
                    self.moriio_wrapper.build_session(
                        unpcaked_local_memory_meta,
                        unpcaked_remote_memory_meta))
            self.built_write_session[
                remote_engine_id] = cur_remote_engine_sessions
        return self.built_write_session[remote_engine_id]

    def _ping(self, zmq_context):
        """Daemon loop: periodically register this instance with the proxy
        over a ZMQ DEALER socket; gives up after MAX_RETRIES failures."""
        PING_INTERVAL = 5
        MAX_RETRIES = 100000

        http_request_address = f"http://{self.request_address}/v1/completions"
        role = "P" if self.is_producer else "D"

        retry_count = 0
        index = 1

        with zmq_context.socket(zmq.DEALER) as sock:
            sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}")

            while True:
                try:
                    data = {
                        "type": "register",
                        "role": role,
                        "index": str(index),
                        "request_address": http_request_address,
                        "handshake_port": self.handshake_port,
                        "notify_port": self.notify_port,
                        "dp_size": self.moriio_config.dp_size,
                        "tp_size": self.moriio_config.tp_size,
                        "transfer_mode": self.mode.name,
                    }

                    sock.send(msgpack.dumps(data))
                    # A successful send resets the failure counter.
                    retry_count = 0

                except ConnectionRefusedError:
                    logger.info(
                        f"Connection refused: {self.local_ip}:{self.local_ping_port} -> "
                        f"{self.proxy_ip}:{self.proxy_ping_port}"
                    )
                    retry_count += 1

                except OSError as e:
                    logger.info(f"OS error when sending ping: {e}")
                    retry_count += 1

                except Exception as e:
                    logger.info(f"Unexpected error when sending ping: {e}")
                    retry_count += 1

                finally:
                    if retry_count >= MAX_RETRIES:
                        logger.error(f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop.")
                        break

                    time.sleep(PING_INTERVAL)
                    index += 1

    def handle_proxy_request(self):
        """Consumer-side daemon loop polling the ROUTER metadata socket.

        NOTE(review): the poll result is currently only logged; no message
        is actually consumed (see TODO below).
        """
        if self.is_producer:
            raise NotImplementedError(
                "prefill instance doesn't need to send kv cache in pull mode")
        while True:
            socks = dict(self.poller.poll())
            logger.debug(f"handle_proxy_request: {socks = }")

            #TODO: inkcherry , check here?
            if self.metadata_socket not in socks:
                continue
            else:
                pass

    def __del__(self):
        """Cleanup background threads on destruction."""
        self._handshake_initiation_executor.shutdown(wait=False)
        if self._moriio_handshake_listener_t:
            self._moriio_handshake_listener_t.join(timeout=0)

    @staticmethod
    def _moriio_handshake_listener(
            metadata: MoRIIOAgentMetadata, ready_event: threading.Event,
            base_port: int, tp_rank: int, dp_rank: int,
            layer_name_to_local_kv_cache_metadata: dict):
        """Background thread for getting new MoRIIO handshakes."""

        # Encode the agent metadata once; it is resent to every peer.
        encoder = msgspec.msgpack.Encoder()
        encoded_data = encoder.encode(metadata)
        size_in_bytes = len(encoded_data)
        logger.debug("Size of encoded MoRIIOAgentMetadata: %s bytes",
                     str(size_in_bytes))

        # Listen for new requests for metadata.
+ host = "*" + logger.info(f"======> mori handeshake starting listening on baseport: {base_port}") + + path = make_zmq_path("tcp", host, base_port ) + logger.info(f"======> mori handeshake sstarting listening on path: {path}") + + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, msg = sock.recv_multipart() + if msg != MoRIIOConstants.GET_META_MSG and msg != MoRIIOConstants.POP_DONE_RECV: + logger.error( + "Connection listener got unexpected message %s", msg) + raise HandshakeError("handshake failed, unexpected msg type") + elif msg == MoRIIOConstants.GET_META_MSG: + sock.send_multipart( + (identity, b"", + encoded_data)) # send local mori io engine meta data + logger.info("MoRIIO handshake listener sent metadata to %s") + # now we send tensor meta data for each block + buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) + sock.send_multipart((identity, b"", buf)) + elif msg == MoRIIOConstants.POP_DONE_RECV: + _, req_id = sock.recv_multipart() + logger.info("MoRIIO handshake listener received done recv for req %s", + req_id.decode()) + else: + pass + + def _moriio_handshake( + self, + host: str, + port: int, + remote_tp_size: int, + expected_engine_id: str, + remote_dp_rank:int=0, + ) -> dict[int, str]: + """Do a MoRIIO handshake with a remote instance.""" + + start_time = time.perf_counter() + + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + + + port_offset = get_port_offset(remote_dp_rank,self.tp_rank) + path = make_zmq_path("tcp", host, port + port_offset) + logger.info("handeshake Querying metadata on path: %s at remote rank %s", path,) + + # Send query for the request. 
+ with zmq_ctx(zmq.DEALER, path) as sock: + logger.info(f"prepare send msg INSTAZNCE: {path}") + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() + if len(received_frame) != 2 or received_frame[0] != b"": + assert 0, f"unexpected frame! {received_frame = }" + + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + logger.info("MoRIIO handshake: get metadata took: %s", + got_metadata_time - start_time) + + self.moriio_wrapper.remote_engine_ip = host + remote_agent_name = self.moriio_wrapper.register_remote_engine( + metadata.agent_metadata) + remote_agent_name=EngineDesc.unpack(metadata.agent_metadata).key + + logger.info(f"MoRIIO handshake: registered remote agent " + f"{remote_agent_name=} for engine ID " + f"{expected_engine_id=},f{path= }") + if len(self.local_kv_cache_metadata) > 0: + logger.warning( + f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + ) + self.local_kv_cache_metadata = [] + if len(self.remote_kv_cache_metadata) > 0: + logger.warning( + f" {len(self.remote_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + ) + self.remote_kv_cache_metadata = [] + + received_frame = sock.recv_multipart() + if len(received_frame) != 2 or received_frame[0] != b"": + assert 0, f"Unexpected frame! {received_frame = }" + buf = received_frame[1] + self.layer_name_to_remote_kv_cache_metadata[ + expected_engine_id] = pickle.loads(buf) + + setup_agent_time = time.perf_counter() + logger.debug("MoRIIO handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + + return {remote_agent_name} + + def _background_moriio_handshake(self, req_id: str, + remote_engine_id: EngineId, + meta: ReqMeta): + # Do MoRIIO handshake in background and add to _ready_requests when done. 
        fut = None
        if remote_engine_id is not None:
            fut = self._handshake_futures.get(remote_engine_id)
        if fut is None:
            host = meta.remote_host
            port = int(meta.remote_handshake_port)
            tp_size = int(meta.tp_size)
            remote_dp_size = int(meta.remote_dp_size)

            # TODO: handle failure state of future in the
            # callback, we want to fail the request in this case.
            def request_ready(_f: Future[Any], entry=(req_id, meta)):
                # Runs once all per-DP-rank handshakes have finished.
                logger.info("MoRIIO handshake done for request %s", req_id)
                self._ready_requests.put(entry)
                self.load_ready_flag = True
                self.write_ready_flags[remote_engine_id] = True

            fut_list = []

            # In dp(prefill)<->dp(decode) communication, we require an
            # all-to-all handshake: one handshake per remote DP rank.
            for cur_dp_rank in range(remote_dp_size):
                dp_engine_id = f"{remote_engine_id}_dp{cur_dp_rank}"

                future = self._handshake_initiation_executor.submit(
                    self._moriio_handshake, host, port, tp_size, dp_engine_id,
                    cur_dp_rank)
                fut_list.append(future)

                def done_callback(f: Future[dict[int, str]], eid=dp_engine_id):
                    with self._handshake_lock:
                        self._handshake_futures.pop(eid, None)
                        try:
                            self._remote_agents[eid] = f.result()
                        except Exception:
                            logger.exception("Handshake with %s failed", eid)

                future.add_done_callback(done_callback)
                self._handshake_futures[dp_engine_id] = future

            def wait_all_dp():
                # Executor is single-worker FIFO, so this task runs after all
                # handshakes above have completed; result() returns instantly.
                for future in fut_list:
                    future.result()
                return True

            all_done_future = self._handshake_initiation_executor.submit(
                wait_all_dp)
            all_done_future.add_done_callback(request_ready)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in moriio."""

        # kv_caches: key = layer name, value = cache tensor,
        # (2, numblocks, blocksize, headnum, headsize) in the non-MLA layout.
        _, first_kv_cache = next(iter(kv_caches.items()))
        kv_elem_size = first_kv_cache.element_size()

        use_mla = len(first_kv_cache.shape) == 3
        assert use_mla == self.use_mla

        if use_mla:
            # MLA case.
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, kv_latent_dim = block_shape
            self.slot_size_bytes = kv_elem_size * kv_latent_dim
        else:
            # [2 (k and v), num_blocks, ...]
            if self._use_flashinfer:
                # FlashInfer swaps 2<->num_blocks dimensions.
                self.num_blocks = first_kv_cache.shape[0]
                block_rank = 4  # [2, block_size, kv_heads, head_dim]
            else:
                self.num_blocks = first_kv_cache.shape[1]
                block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, n_kv_heads, head_dim = block_shape[-3:]
            # head size in bytes.
            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
        assert block_size == self.block_size
        # TODO(tms): self.block_len needs to be per-layer for sliding window,
        # hybrid attn, etc
        # block size in bytes
        self.block_len = kv_elem_size * math.prod(block_shape)
        self.kv_cache_shape = first_kv_cache.shape
        self.block_shape = block_shape
        self.kv_element_size = kv_elem_size

        self.dst_num_blocks[self.engine_id] = self.num_blocks
        self.kv_caches = kv_caches  # layer name to kv cache
        kv_caches_base_addr = []
        caches_data = []

        # Note(tms): I modified this from the original region setup code.
        # K and V are now in different regions. Advantage is that we can
        # elegantly support MLA and any cases where the K and V tensors
        # are non-contiguous (it's not locally guaranteed that they will be)
        # Disadvantage is that the encoded MoRIIOAgentMetadata is now larger
        # (roughly 8KB vs 5KB).
        # Conversely for FlashInfer, K and V are transferred in the same tensor
        # to better exploit the memory layout (ie num_blocks is the first dim).
        for cache_or_caches in kv_caches.values():
            cache_list = [
                cache_or_caches
            ] if use_mla or self._use_flashinfer else cache_or_caches
            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len
                caches_data.append(
                    (base_addr, region_len, cache.device.index, ""))
                kv_caches_base_addr.append(base_addr)

        for layer_name, kv_cache in kv_caches.items():
            if layer_name not in self.layer_name_to_local_kv_cache_metadata:
                self.layer_name_to_local_kv_cache_metadata[layer_name] = []

            moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(
                kv_cache)
            self.layer_name_to_local_kv_cache_metadata[layer_name].append(
                moriio_mem_metadata)

            # NOTE(review): `cache` here is the stale loop variable left over
            # from the previous loop (always the last cache registered), not
            # `kv_cache` -- looks like it should be kv_cache.nelement() *
            # kv_cache.element_size(); TODO confirm.
            self.local_kv_cache_size.append(cache.nelement() *
                                            cache.element_size())

        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
        self.num_regions = len(caches_data)
        self.num_layers = len(self.kv_caches.keys())

        # Optimization for models with local attention (Llama 4)
        if self.vllm_config.model_config.hf_config.model_type == "llama4":
            from transformers import Llama4TextConfig
            assert isinstance(self.vllm_config.model_config.hf_text_config,
                              Llama4TextConfig)
            llama4_config = self.vllm_config.model_config.hf_text_config
            no_rope_layers = llama4_config.no_rope_layers
            chunk_size = llama4_config.attention_chunk_size
            chunk_block_size = math.ceil(chunk_size / self.block_size)
            for layer_idx in range(self.num_layers):
                # no_rope_layers[layer_idx] == 0 means NoPE (global)
                # Any other value means RoPE (local chunked)
                is_local_attention = no_rope_layers[layer_idx] != 0
                block_window = chunk_block_size if is_local_attention else None
                self.block_window_per_layer.append(block_window)
            logger.debug("Llama 4 block window per layer mapping: %s",
                         self.block_window_per_layer)
            assert len(self.block_window_per_layer) == self.num_layers

        metadata = MoRIIOAgentMetadata(
            engine_id=self.engine_id,
            agent_metadata=self.moriio_wrapper.get_agent_metadata(),
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
            num_blocks=self.num_blocks,
            block_len=self.block_len,
            attn_backend_name=self.backend_name)
        ready_event = threading.Event()
        self._moriio_handshake_listener_t = threading.Thread(
            target=self._moriio_handshake_listener,
            args=(metadata, ready_event, self.side_channel_port, self.tp_rank,
                  self.dp_rank, self.layer_name_to_local_kv_cache_metadata),
            daemon=True,
            name="moriio_handshake_listener")
        self._moriio_handshake_listener_t.start()
        ready_event.wait()  # Wait for listener ZMQ socket to be ready.
        self.moriio_wrapper.async_wait_reqid()

    def get_finished(self) -> tuple[set[str], set[str]]:
        """
        Get requests that are done sending or recving on this specific worker.
        The scheduler process (via the MultiprocExecutor) will use this output
        to track which workers are done.
+ """ + + done_sending, done_recving = set(), set() + + if self.is_producer: + done_sending = self.moriio_wrapper.pop_finished_req_ids() + if self.mode == MoRIIOMode.WRITE: + done_recving = set() + else: + done_recving=self._pop_done_transfers() + else: + if self.mode == MoRIIOMode.WRITE: + self.moriio_wrapper.async_wait_reqid() + done_sending, done_recving = set( + ), self.moriio_wrapper.pop_finished_write_req_ids() + + return done_sending, done_recving + + + def _pop_done_transfers(self) -> set[str]: + + done_req_ids: set[str] = set() + for req_id, status_list in self._recving_transfers.items(): + if status_list[-1].Succeeded(): + done_req_ids.add(req_id) + + self.moriio_wrapper.send_notify( + req_id,self._recving_transfers_callback_addr[req_id][0], + self._recving_transfers_callback_addr[req_id][1]) + del self._recving_transfers[req_id] + del self._recving_transfers_callback_addr[req_id] + + return done_req_ids + + + def save_kv_layer(self, metadata: MoRIIOConnectorMetadata, layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs): + + if not self.is_producer: + return + if self.mode == MoRIIOMode.READ: + return + remote_engine_id = None + + + for req_id, meta in metadata.reqs_to_save.items(): + remote_engine_id = meta.remote_engine_id + # we only need to check if dp0 in rank + remote_engine_id = str(meta.remote_host) + ":" + str( + meta.remote_handshake_port) + + meta.remote_engine_id = remote_engine_id + + # TODO: mz get_remote_engine_id() for engine_id mapping. + dp0_remote_engine_id = f"{remote_engine_id}_dp0" + if dp0_remote_engine_id not in self._remote_agents: + # Initiate handshake with remote engine to exchange metadata. 
                with self._handshake_lock:
                    # NOTE(review): membership is tested with the bare
                    # remote_engine_id here but _remote_agents is keyed by
                    # "<engine_id>_dp<rank>" -- confirm intended key.
                    if remote_engine_id not in self._remote_agents:
                        logger.info(
                            f"*****background moriio {remote_engine_id = }")
                        self._background_moriio_handshake(
                            req_id, remote_engine_id, meta)

                    continue
            self._write_blocks_for_req(req_id, meta, layer_name, kv_layer)

        # NOTE(review): busy-wait until the handshake callback sets
        # write_ready_flags and the request is re-queued; spins on the CPU.
        while True:
            if remote_engine_id is None:
                break
            if self._ready_requests.empty() and remote_engine_id not in self.write_ready_flags:
                continue
            elif not self._ready_requests.empty() and (remote_engine_id
                                                       in self.write_ready_flags):
                self._write_blocks_for_req(*self._ready_requests.get_nowait(),
                                           layer_name, kv_layer)
                break
            else:
                break

    def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
        """
        Start loading by triggering non-blocking moriio_xfer.
        We check for these trnxs to complete in each step().
        """
        if self.is_producer:
            self.moriio_wrapper.async_wait_reqid()
            return
        if self.mode == MoRIIOMode.WRITE:
            return

        wait_handshage_readd_req = False
        remote_engine_id = None

        for req_id, meta in metadata.reqs_to_recv.items():
            remote_engine_id = str(meta.remote_host) + ":" + str(
                meta.remote_handshake_port)
            meta.remote_engine_id = remote_engine_id
            dp0_remote_engine_id = f"{remote_engine_id}_dp0"
            if dp0_remote_engine_id not in self._remote_agents:
                # Initiate handshake with remote engine to exchange metadata.
                with self._handshake_lock:
                    if remote_engine_id not in self._remote_agents:
                        self._background_moriio_handshake(
                            req_id, remote_engine_id, meta)
                        wait_handshage_readd_req = True

                continue

            # Handshake already completed, start async read xfer.
            self._read_blocks_for_req(req_id, meta)

        # Start transfers for requests whose handshakes have now finished.
        # NOTE(review): busy-wait loop, same pattern as in save_kv_layer.
        while True:  # TODO
            if self._ready_requests.empty(
            ) and not self.load_ready_flag and wait_handshage_readd_req:
                continue
            elif not self._ready_requests.empty() and self.load_ready_flag:
                self._read_blocks_for_req(*self._ready_requests.get_nowait())
                break
            else:
                break

        self._reqs_to_send.update(metadata.reqs_to_send)

    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
        # Thin adapter: unpack ReqMeta into _read_blocks keyword arguments.
        logger.debug(
            "Remote agent %s available, calling _read_blocks for req %s",
            meta.remote_engine_id, req_id)
        self._read_blocks(
            request_id=req_id,
            dst_engine_id=meta.remote_engine_id,
            local_block_ids=meta.local_block_ids,
            remote_block_ids=meta.remote_block_ids,
            remote_host=meta.remote_host,
            remote_notify_port=meta.remote_notify_port,
        )

    def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name,
                              kv_layer):
        # Thin adapter: unpack ReqMeta and enqueue an asynchronous write.
        self.schedule_write_blocks(request_id=req_id,
                                   dst_engine_id=meta.remote_engine_id,
                                   local_block_ids=meta.local_block_ids,
                                   remote_block_ids=meta.remote_block_ids,
                                   layer_name=layer_name,
                                   kv_layer=kv_layer,
                                   remote_notify_port=meta.remote_notify_port,
                                   remote_ip=meta.remote_host)

    def _is_last_layer(self, layer_name):
        # True iff layer_name is the last key in insertion order.
        if layer_name == list(self.kv_caches.keys())[-1]:
            return True
        return False

    def _is_first_layer(self, layer_name):
        # True iff layer_name is the first key in insertion order.
        if layer_name == list(self.kv_caches.keys())[0]:
            return True
        return False

    def merge_contiguous_blocks(
        self,
        offsets_local: List[int],
        offsets_remote: List[int],
        sizes: List[int],
        assume_sorted: bool = False
    ) -> Tuple[List[int], List[int], List[int]]:
        """Coalesce adjacent (local, remote, size) transfer descriptors into
        maximal runs that are contiguous on BOTH sides, reducing the number
        of RDMA operations. Returns plain-int lists."""
        n = len(offsets_local)
        if n == 0:
            return [], [], []
        if not (n == len(offsets_remote) == len(sizes)):
            raise ValueError("Input list lengths mismatch")
        local_arr = np.fromiter(offsets_local, dtype=np.int64, count=n)
        remote_arr = np.fromiter(offsets_remote, dtype=np.int64, count=n)
        sizes_arr = 
np.fromiter(sizes, dtype=np.int64, count=n)

        if assume_sorted:
            local_sorted = local_arr
            remote_sorted = remote_arr
            sizes_sorted = sizes_arr
        else:
            # Skip the argsort when the offsets are already ascending.
            if np.all(local_arr[:-1] <= local_arr[1:]):
                local_sorted = local_arr
                remote_sorted = remote_arr
                sizes_sorted = sizes_arr
            else:
                sort_idx = np.argsort(local_arr, kind="stable")
                local_sorted = local_arr[sort_idx]
                remote_sorted = remote_arr[sort_idx]
                sizes_sorted = sizes_arr[sort_idx]

        if n == 1:
            return [int(local_sorted[0])], [int(remote_sorted[0])
                                            ], [int(sizes_sorted[0])]

        # Entry i+1 is mergeable with entry i when both local and remote
        # offsets advance by exactly the previous entry's size.
        diff_local = local_sorted[1:] - local_sorted[:-1]
        diff_remote = remote_sorted[1:] - remote_sorted[:-1]
        prev_size = sizes_sorted[:-1]

        contiguous = (diff_local == prev_size) & (diff_remote == prev_size)

        if not contiguous.any():
            # Nothing mergeable: return the (sorted) inputs unchanged.
            return local_sorted.tolist(), remote_sorted.tolist(
            ), sizes_sorted.tolist()

        if contiguous.all():
            # Fully contiguous: a single transfer covers everything.
            total_size = int(sizes_sorted.sum())
            return [int(local_sorted[0])], [int(remote_sorted[0])
                                            ], [total_size]

        # Segment boundaries fall wherever contiguity breaks.
        break_positions = np.flatnonzero(~contiguous) + 1
        segment_starts = np.concatenate(([0], break_positions))
        segment_ends = np.concatenate((break_positions, [n]))

        seg_count = len(segment_starts)
        merged_local = [0] * seg_count
        merged_remote = [0] * seg_count
        merged_sizes = [0] * seg_count

        for si in range(seg_count):
            s = segment_starts[si]
            e = segment_ends[si]
            merged_local[si] = int(local_sorted[s])
            merged_remote[si] = int(remote_sorted[s])
            # Segment size = end of last entry minus start of first entry.
            merged_sizes[si] = int(local_sorted[e - 1] + sizes_sorted[e - 1] -
                                   local_sorted[s])

        return merged_local, merged_remote, merged_sizes

    def _compute_block_transfer_offsets(
        self,
        layer_name: str,
        local_block_ids: list[int],
        remote_block_ids: list[int],
    ) -> tuple[list[int], list[int], list[int]]:
        """Compute transfer offsets for block data.

        Args:
            layer_name: Name of the layer to transfer
            local_block_ids: IDs of local blocks
            remote_block_ids: IDs of remote blocks

        Returns:
            Tuple of (local_offsets, remote_offsets, transfer_sizes)
        """
        # 3-D cache tensor => MLA layout (no separate K/V planes).
        is_mla = (len(self.kv_cache_shape) == 3)
        stride = self.kv_caches[layer_name].stride()
        sz = self.kv_caches[layer_name].element_size()
        if is_mla:
            blknum, blksize, hs = self.kv_cache_shape
            hn = 1
            block_stride = stride[0]
            ktov_stride = None
        else:
            _, blknum, blksize, hn, hs = self.kv_cache_shape
            # Stride from the K plane to the V plane (leading dim of size 2).
            ktov_stride = stride[0]
            block_stride = stride[1]

        transfer_size_byte = blksize * hn * hs * sz
        # One descriptor per block for MLA; two (K and V) otherwise.
        per_block = 1 if is_mla else 2
        total = len(local_block_ids) * per_block
        offset_local = [0] * total
        offset_remote = [0] * total
        sizes = [transfer_size_byte] * total

        w = 0
        for i, lb in enumerate(local_block_ids):
            rb = remote_block_ids[i]
            # K
            offset_local[w] = sz * (lb * block_stride)
            offset_remote[w] = sz * (rb * block_stride)
            w += 1
            if not is_mla:
                # V
                offset_local[w] = sz * (1 * ktov_stride + lb * block_stride)
                offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride)
                w += 1

        # NOTE(review): assume_sorted=True presumes local_block_ids are
        # ascending; unsorted ids would merge incorrectly -- TODO confirm.
        merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
            offset_local, offset_remote, sizes, assume_sorted=True)
        return merged_l, merged_r, merged_s

    def _read_blocks(self, local_block_ids: list[int],
                     remote_block_ids: list[int], dst_engine_id: str,
                     request_id: str,
                     remote_host: str,
                     remote_notify_port: int) -> None:

        if self.mode == MoRIIOMode.WRITE:
            return
        # we only test TP<->TP in read mode
        # assert self.dp_rank>0, "only test TP<->TP in read mode"
        dst_engine_id += "_dp0"
        sessions = self._get_built_session(dst_engine_id)

        first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
        offs = self._compute_block_transfer_offsets(first_layer, local_block_ids, remote_block_ids)
        a, b, c = offs[0], offs[1], offs[2]

        for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
            sess_idx = 
list(self.layer_name_to_local_kv_cache_metadata.keys()).index(layer_name) + transfer_status=self.moriio_wrapper.read_remote_data(c, a, b, sessions[sess_idx]) + + self._recving_transfers[request_id].append(transfer_status) + self._recving_transfers_callback_addr[request_id]=(remote_host,remote_notify_port + self.tp_rank) + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) \ No newline at end of file From b3e31b42d81a80015734442e6ccd6241e032d5c6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 05:21:06 +0000 Subject: [PATCH 02/62] update gitignore Signed-off-by: inkcherry --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7cda86478664f..9e8f9de7341b0 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,4 @@ ep_kernels_workspace/ # Allow tracked library source folders under submodules (e.g., benchmarks/lib) !vllm/benchmarks/lib/ +examples/online_serving/disaggregated_serving_p2p_moriio_xpyd/ From a7ea23d16d955ad317240a9b017cd17bec357e92 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 07:22:31 +0000 Subject: [PATCH 03/62] fix with new main branch Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 4e4daebd3ab77..35a686e7a8fd4 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -20,7 +20,9 @@ import torch import zmq from vllm import envs -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.selector import get_attn_backend +from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -29,8 +31,7 @@ from vllm.distributed.parallel_state import ( get_tp_group, get_world_group) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.platforms import _Backend -from vllm.utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus from weakref import ref as weakref_ref @@ -38,6 +39,7 @@ from weakref import ref as weakref_ref if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request from dataclasses import field @@ -835,7 +837,7 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): class MoRIIOConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None,): assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id = str( @@ -927,6 +929,16 @@ class MoRIIOConnector(KVConnectorBase_V1): def wait_for_save(self): pass + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently 
set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + try : + return self._connector_metadata is not None + except AttributeError: + return False class MoRIIOConnectorScheduler: @@ -1402,8 +1414,11 @@ class MoRIIOConnectorWorker: self.block_size, use_mla=self.use_mla) self.backend_name = backend.get_name() - attn_backend = backend_name_to_enum(self.backend_name) - self._use_flashinfer = attn_backend == _Backend.FLASHINFER + attn_backend = AttentionBackendEnum[self.backend_name] + self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + # attn_backend = backend_name_to_enum(self.backend_name) + # self._use_flashinfer = attn_backend == _Backend.FLASHINFER logger.debug("Detected attention backend %s", self.backend_name) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} From 675943e018f12cbe5c37d0e00eb754f5ab08a3a0 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 08:35:28 +0000 Subject: [PATCH 04/62] fix dp router Signed-off-by: inkcherry --- vllm/entrypoints/openai/serving_completion.py | 1 + vllm/entrypoints/openai/serving_engine.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a114b77ebc16b..4d47912d25315 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -225,6 +225,7 @@ class OpenAIServingCompletion(OpenAIServing): lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, + data_parallel_rank=data_parallel_rank, ) generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c50b0c4a23e17..4d9903c9c5745 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1172,6 +1172,7 @@ class 
OpenAIServing: lora_request: LoRARequest | None, trace_headers: Mapping[str, str] | None, priority: int, + data_parallel_rank: int, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" tokenization_kwargs: dict[str, Any] = {} @@ -1187,6 +1188,7 @@ class OpenAIServing: tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, + data_parallel_rank=data_parallel_rank, ) return engine_request, tokenization_kwargs From e0f4336a5b2833e6ca0c45382692f69e8f679fc8 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 08:36:16 +0000 Subject: [PATCH 05/62] format Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/factory.py | 3 +- .../kv_connector/v1/moriio_connector.py | 1207 +++++++++-------- 2 files changed, 662 insertions(+), 548 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 42821d341d6b6..954a5153ff67d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -182,7 +182,8 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MoRIIOConnector", "vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector", - "MoRIIOConnector") + "MoRIIOConnector", +) KVConnectorFactory.register_connector( "OffloadingConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 35a686e7a8fd4..2856733484cab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -11,7 +11,8 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing 
import TYPE_CHECKING, Any, Optional +from weakref import ref as weakref_ref import msgpack import msgspec @@ -20,21 +21,25 @@ import torch import zmq from vllm import envs -from vllm.attention.selector import get_attn_backend from vllm.attention.backends.registry import AttentionBackendEnum - +from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, get_world_group) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + get_world_group, +) from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus -from weakref import ref as weakref_ref if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -42,50 +47,55 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -from dataclasses import field -from queue import Empty, Queue -from enum import Enum import logging - +from dataclasses import field +from enum import Enum +from queue import Empty, Queue logger = init_logger(__name__) -Transfer = tuple[int, float] +Transfer = tuple[int, float] EngineId = str ReqId = str class MoRIIOConstants: """Constants for MoRIIO connector.""" - + # ZMQ message types GET_META_MSG = b"get_meta_msg" POP_DONE_RECV = b"pop_done_recv" OVER = b"OVER" COMPLETION_PREFIX = "cmpl" - # Default GPU count per node for standard configurations - RANK_PER_NODE = 8 - - - + RANK_PER_NODE = 8 + + try: import mori - from mori.io import 
(BackendType, EngineDesc, IOEngine, IOEngineConfig, - MemoryDesc, StatusCode) + from mori.io import ( + BackendType, + EngineDesc, + IOEngine, + IOEngineConfig, + MemoryDesc, + StatusCode, + ) + logger.info("MoRIIO is available") MoRIIO_enabled = True except ImportError: logger.error("MoRIIO is not available") MoRIIO_enabled = False - + + @dataclass class WriteTask: request_id: str dst_engine_id: str local_block_ids: list[int] - remote_block_ids_hint: Optional[list[int]] + remote_block_ids_hint: list[int] | None layer_name: str event: torch.cuda.Event remote_notify_port: int @@ -93,9 +103,11 @@ class WriteTask: enqueue_time: float = field(default_factory=time.perf_counter) retried: int = 0 + @dataclass class LayerTransferPlan: """Plan for transferring a single layer.""" + request_id: str layer_name: str sess_idx: int @@ -103,10 +115,12 @@ class LayerTransferPlan: transfer_remote_offsets: list[int] transfer_sizes: list[int] use_batch: bool = True - + + @dataclass class RemoteAllocInfo: """Information about remote block allocation.""" + block_ids: list[int] writes_done: int = 0 decode_dp_rank: int = 0 @@ -118,15 +132,16 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" + class RoleManager: """Manages role state across the connector.""" - + _instance: Optional["RoleManager"] = None _lock = threading.Lock() - + def __init__(self) -> None: self._role: ROLE = ROLE.NOTINIT - + @classmethod def get_instance(cls) -> "RoleManager": if cls._instance is None: @@ -134,12 +149,12 @@ class RoleManager: if cls._instance is None: cls._instance = cls() return cls._instance - + def set_role(self, role: ROLE) -> None: """Set the current role.""" with self._lock: self._role = role - + def get_role(self) -> ROLE: """Get the current role.""" return self._role @@ -149,6 +164,7 @@ def set_role(role: ROLE): """Set the global role.""" RoleManager.get_instance().set_role(role) + def get_role() -> ROLE: """Get the global role.""" return RoleManager.get_instance().get_role() @@ 
-158,30 +174,37 @@ class MoRIIOMode(Enum): READ = "read" WRITE = "write" + class MoRIIOError(Exception): """Base exception for MoRIIO operations.""" + pass + class HandshakeError(MoRIIOError): """Exception raised when handshake fails.""" + pass + class TransferError(MoRIIOError): """Exception raised when transfer fails.""" + pass def get_moriio_mode() -> MoRIIOMode: - read_mode = os.environ.get('MORIIO_CONNECTOR_READ_MODE', 'false').lower() + read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower() # logger.info(f"MoRIIO Connector Read Mode = {read_mode}") - if read_mode in ('true', '1', 'yes', 'on'): + if read_mode in ("true", "1", "yes", "on"): return MoRIIOMode.READ else: return MoRIIOMode.WRITE -def get_port_offset(dp_rank: int,tp_rank: int, tp_size:int=1) -> int: - return ((dp_rank)*tp_size+tp_rank )%MoRIIOConstants.RANK_PER_NODE +def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: + return ((dp_rank) * tp_size + tp_rank) % MoRIIOConstants.RANK_PER_NODE + @dataclass class MoRIIOConfig: @@ -198,18 +221,16 @@ class MoRIIOConfig: dp_rank: int dp_size: int tp_size: int - + @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - - # Port Configuration: # local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it) # proxy_ping_port -> Remote proxy's heartbeat ingress port # http_port -> Instance's HTTP service endpoint # local_kv_port -> KV service port for Mori engine # notify_port -> For synchronizing stages between nodes - + kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -217,53 +238,54 @@ class MoRIIOConfig: base_kv_port = int(kv_transfer_config.kv_port) base_ping_port = int(extra_config["local_ping_port"]) base_notify_port = int(extra_config["notify_port"]) - dp_size=vllm_config.parallel_config.data_parallel_size - tp_size=get_tensor_model_parallel_world_size() - 
port_offset=get_port_offset(dp_rank,tp_rank) - + dp_size = vllm_config.parallel_config.data_parallel_size + tp_size = get_tensor_model_parallel_world_size() + port_offset = get_port_offset(dp_rank, tp_rank) + return cls( local_ip=get_ip(), local_kv_port=base_kv_port + port_offset, proxy_ip=extra_config["proxy_ip"], proxy_port=int(extra_config["proxy_port"]), - local_ping_port=base_ping_port+port_offset, + local_ping_port=base_ping_port + port_offset, proxy_ping_port=int(extra_config["proxy_ping_port"]), - http_port=int(extra_config['http_port']), - handshake_port=int(extra_config['handshake_port']), + http_port=int(extra_config["http_port"]), + handshake_port=int(extra_config["handshake_port"]), notify_port=base_notify_port + port_offset, tp_rank=tp_rank, dp_rank=dp_rank, dp_size=dp_size, - tp_size=tp_size, + tp_size=tp_size, ) """Write task execution logic for MoRIIO connector.""" + class MoRIIOWriter: """Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library for RDMA-based communication between prefill and decode instances.""" - + def __init__(self, worker: "MoRIIOConnectorWorker"): """Initialize the writer. - + Args: worker: Reference to the parent worker """ # self.worker = worker - self._worker_ref: "weakref_ref[MoRIIOConnectorWorker]" = weakref_ref(worker) + self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) self._write_task_q: Queue[WriteTask] = Queue() self._write_worker_started = False self._write_worker_lock = threading.Lock() self._deferred_tasks: list[WriteTask] = [] - + @property def worker(self) -> "MoRIIOConnectorWorker": """Get the worker instance. 
- + Returns: The parent worker instance - + Raises: RuntimeError: If worker has been garbage collected """ @@ -271,47 +293,42 @@ class MoRIIOWriter: if worker is None: raise RuntimeError("Parent worker has been garbage collected") return worker - + def ensure_worker_started(self) -> None: """Ensure the background write worker is running.""" if self._write_worker_started: return self._write_worker_started = True with self._write_worker_lock: - thread = threading.Thread( - target=self._write_worker_loop, - daemon=True, - name="moriio-write-worker" + target=self._write_worker_loop, daemon=True, name="moriio-write-worker" ) thread.start() logger.info("Started MoRIIO write worker thread") - + def schedule_write(self, task: WriteTask) -> None: """Schedule a write task. - + Args: task: The write task to schedule """ self.ensure_worker_started() self._write_task_q.put(task) - + def _write_worker_loop(self) -> None: """Main loop for the write worker thread.""" logger.info("Write worker loop started") - + while True: # Process deferred tasks first self._process_deferred_tasks() - + # Get new task try: - task = self._write_task_q.get( - timeout=0.01 - ) + task = self._write_task_q.get(timeout=0.01) except Empty: continue - + # Check if remote blocks are ready if not self._is_remote_ready(task): # task.retry_count += 1 @@ -321,128 +338,114 @@ class MoRIIOWriter: # task.request_id, task.retry_count # ) continue - + # Execute the task self._execute_write_task(task) - - + def _process_deferred_tasks(self) -> None: """Process tasks that were previously deferred.""" if not self._deferred_tasks: return - + still_deferred: list[WriteTask] = [] for task in self._deferred_tasks: if self._is_remote_ready(task): - - self._execute_write_task(task) + self._execute_write_task(task) else: still_deferred.append(task) - + self._deferred_tasks = still_deferred - + def _is_remote_ready(self, task: WriteTask) -> bool: """Check if remote blocks are allocated for this task. 
- + Args: task: The write task - + Returns: True if remote blocks are ready """ - return (task.request_id in - self.worker.moriio_wrapper.done_remote_allocate_req_dict) - + return ( + task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict + ) + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: """Get remote allocation info for a request. - + Args: request_id: The request ID - + Returns: Remote allocation information - + Raises: KeyError: If allocation info is missing """ try: - return self.worker.moriio_wrapper.done_remote_allocate_req_dict[ - request_id - ] + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] except KeyError as e: raise KeyError( f"Remote allocation info missing for request {request_id}" ) from e - + def _execute_write_task(self, task: WriteTask) -> None: """Execute a single write task. - + Args: task: The write task to execute - + """ # Get remote allocation info request_info = self._get_remote_alloc_info(task.request_id) - + if request_info.block_ids is None: - logger.debug( - "Request %s remote block IDs not ready", - task.request_id - ) + logger.debug("Request %s remote block IDs not ready", task.request_id) return - + # Wait for CUDA event task.event.synchronize() - + # Update engine ID with DP rank - task.dst_engine_id = ( - f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" - ) - + task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" + # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) - + # Prepare transfer plan plan = self._prepare_transfer_plan(task, request_info) - + # Execute transfer self._do_layer_write(plan, sessions) - + # Finalize if all layers complete self._finalize_if_complete(task, request_info) - + def _prepare_transfer_plan( - self, - task: WriteTask, - request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo ) -> LayerTransferPlan: """Prepare the transfer plan for a 
layer. - + Args: task: The write task request_info: Remote allocation information - + Returns: The transfer plan """ # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, - task.local_block_ids, - request_info.block_ids + task.layer_name, task.local_block_ids, request_info.block_ids ) request_info.transfer_offset = offsets - + # Get session index - layer_names = list( - self.worker.layer_name_to_local_kv_cache_metadata.keys() - ) + layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) sess_idx = layer_names.index(task.layer_name) - + local_off, remote_off, sizes = request_info.transfer_offset - + return LayerTransferPlan( request_id=task.request_id, layer_name=task.layer_name, @@ -450,16 +453,12 @@ class MoRIIOWriter: transfer_local_offsets=local_off, transfer_remote_offsets=remote_off, transfer_sizes=sizes, - use_batch=True + use_batch=True, ) - - def _do_layer_write( - self, - plan: LayerTransferPlan, - sessions: list - ) -> None: + + def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: """Perform the actual layer write. - + Args: plan: The transfer plan sessions: List of transfer sessions @@ -469,7 +468,7 @@ class MoRIIOWriter: plan.transfer_sizes, plan.transfer_local_offsets, plan.transfer_remote_offsets, - sessions[plan.sess_idx] + sessions[plan.sess_idx], ) else: for i in range(len(plan.transfer_local_offsets)): @@ -477,60 +476,59 @@ class MoRIIOWriter: plan.transfer_sizes[i], plan.transfer_local_offsets[i], plan.transfer_remote_offsets[i], - plan.sess_idx + plan.sess_idx, ) - + def _finalize_if_complete( - self, - task: WriteTask, - request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo ) -> None: """Finalize transfer if all layers are complete. 
- + Args: task: The write task request_info: Remote allocation information """ request_info.writes_done += 1 - + if request_info.writes_done >= self.worker.num_layers: # Wait for transfer to complete self.worker.moriio_wrapper.waiting_for_transfer_complete() - - + remote_port = task.remote_notify_port + get_port_offset( - request_info.decode_dp_rank, - self.worker.tp_rank + request_info.decode_dp_rank, self.worker.tp_rank ) # TODO: # Consider using RDMA immediate data in decode side to eliminate the need for this notification. # Consider including the first gen token from prefill in the notification - + # Send completion notification self.worker.moriio_wrapper.send_notify( - task.request_id, - task.remote_ip, - remote_port + task.request_id, task.remote_ip, remote_port ) - del self.worker.moriio_wrapper.done_remote_allocate_req_dict[task.request_id] + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + task.request_id + ] logger.debug( "Completed transfer for request %s, notified port %d", - task.request_id, remote_port + task.request_id, + remote_port, ) + + class MoRIIOWrapper: """Wrapper for MoRIIO engine operations. - + Handles both producer and consumer roles for KV cache transfers. - + Args: moriio_engine: MoRIIO engine instance tp_rank: Tensor parallel rank dp_rank: Data parallel rank """ - - def __init__(self, moriio_engine=None,tp_rank=0,dp_rank=0): - self.tp_rank=tp_rank - self.dp_rank=dp_rank + + def __init__(self, moriio_engine=None, tp_rank=0, dp_rank=0): + self.tp_rank = tp_rank + self.dp_rank = dp_rank self.moriio_engine = moriio_engine self.remote_memory_metadata = None self.local_memory_registered = False @@ -548,10 +546,11 @@ class MoRIIOWrapper: self.sessions = [] self.kv_caches = None self.paths = {} - def set_moriio_engine(self, moriio_engine): - assert moriio_engine is not None, "You Cannot pass None engine to MoRIIOWrapper!" + assert moriio_engine is not None, ( + "You Cannot pass None engine to MoRIIOWrapper!" 
+ ) self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): @@ -563,15 +562,15 @@ class MoRIIOWrapper: return engine_metadata_packed def register_remote_engine(self, remote_packed_engine_metadata): - consumer_engine_metadata = EngineDesc.unpack( - remote_packed_engine_metadata) + consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) self.moriio_engine.register_remote_engine(consumer_engine_metadata) return consumer_engine_metadata.key def register_local_tensor(self, tensor: torch.Tensor): try: self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor) + tensor + ) local_memory_metadata_packed = self.local_memory_metadata.pack() except Exception as e: raise MoRIIOError(f"Failed to register local memory: {e}") from e @@ -582,45 +581,47 @@ class MoRIIOWrapper: return MemoryDesc.unpack(packed_memory_metadata) def build_session(self, local_memory_metadata, remote_memory_metadata): - return self.moriio_engine.create_session(local_memory_metadata, - remote_memory_metadata) + return self.moriio_engine.create_session( + local_memory_metadata, remote_memory_metadata + ) - def read_remote_data(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - session=None): + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): assert self.local_memory_registered, "You have not register local memory data!" transfer_status = session.batch_read( - local_offset, remote_offset, transfer_size_byte, - self.moriio_engine.allocate_transfer_uid()) + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) return transfer_status - def write_remote_data(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - session=None): + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): assert self.local_memory_registered, "You have not register local memory data!" 
write_uid = self.moriio_engine.allocate_transfer_uid() - transfer_status = session.batch_write(local_offset, remote_offset, - transfer_size_byte, write_uid) + transfer_status = session.batch_write( + local_offset, remote_offset, transfer_size_byte, write_uid + ) with self.lock: self.transfer_status.append(transfer_status) - def write_remote_data_single(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - sess_idx=0): + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): assert self.local_memory_registered, "You have not register local memory data!" transfer_status = self.sessions[sess_idx].write( - local_offset, remote_offset, transfer_size_byte, - self.moriio_engine.allocate_transfer_uid()) + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) with self.lock: self.transfer_status.append(transfer_status) @@ -640,14 +641,12 @@ class MoRIIOWrapper: logger.error( f"Transfer failed: {status.Message()}, Code: {status.Code()}" ) - raise TransferError(f"MoRIIO transfer failed!") + raise TransferError("MoRIIO transfer failed!") except Exception as e: logger.error(f"Transfer {status} failed: {e}") raise def async_wait_reqid(self): - - assert self.notify_port is not None, "Notify port cannot be None" if self.notify_thread is not None: @@ -668,9 +667,9 @@ class MoRIIOWrapper: raise HandshakeError(f"Error processing message: {e}") from e continue - self.notify_thread = threading.Thread(target=_async_wait, - daemon=True, - name="moriio-notify-listener") + self.notify_thread = threading.Thread( + target=_async_wait, daemon=True, name="moriio-notify-listener" + ) self.notify_thread.start() def _handle_message(self, msg: bytes): @@ -689,8 +688,7 @@ class MoRIIOWrapper: handled = True return - except (msgpack.exceptions.ExtraData, - msgpack.exceptions.UnpackException): + except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): pass try: @@ 
-706,11 +704,15 @@ class MoRIIOWrapper: def _handle_structured_message(self, data: dict): req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) - decode_dp_rank=data.get("decode_rank",0) - assert len(block_notify_list) > 0, "block_notify_list cannot be empty in remote allocate message" + decode_dp_rank = data.get("decode_rank", 0) + assert len(block_notify_list) > 0, ( + "block_notify_list cannot be empty in remote allocate message" + ) with self.lock: - self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(block_ids=block_notify_list,decode_dp_rank=decode_dp_rank) + self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( + block_ids=block_notify_list, decode_dp_rank=decode_dp_rank + ) def _handle_completion_message(self, msg: str): with self.lock: @@ -728,10 +730,9 @@ class MoRIIOWrapper: if path not in self.paths: ctx = zmq.Context() - sock = make_zmq_socket(ctx=ctx, - path=path, - socket_type=zmq.DEALER, - bind=False) + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) self.paths[path] = sock req_list = req_ids if isinstance(req_ids, list) else [req_ids] @@ -740,8 +741,7 @@ class MoRIIOWrapper: try: for req_id in req_list: if not isinstance(req_id, str): - logger.warning( - f"Invalid req_id type: {type(req_id)}, expected str") + logger.warning(f"Invalid req_id type: {type(req_id)}, expected str") continue sock.send(req_id.encode("utf-8")) except Exception as e: @@ -764,12 +764,12 @@ class MoRIIOWrapper: return done_write_cache - class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True): + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] @@ -781,6 +781,7 @@ class MoRIIOAgentMetadata( @dataclass class ReqMeta: """Metadata for a single request.""" + local_block_ids: list[int] 
remote_block_ids: list[int] remote_host: str @@ -793,7 +794,6 @@ class ReqMeta: class MoRIIOConnectorMetadata(KVConnectorMetadata): - def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} @@ -817,17 +817,16 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): kv_transfer_params: dict[str, Any], write_mode=False, ): - _req = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], - remote_handshake_port=kv_transfer_params['remote_handshake_port'], - remote_notify_port=kv_transfer_params.get('remote_notify_port'), + remote_handshake_port=kv_transfer_params["remote_handshake_port"], + remote_notify_port=kv_transfer_params.get("remote_notify_port"), tp_size=kv_transfer_params.get("tp_size", 1), - remote_dp_size=kv_transfer_params.get("remote_dp_size", 1) + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), ) if write_mode: self.reqs_to_save[request_id] = _req @@ -836,43 +835,55 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): class MoRIIOConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None,): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None - self.engine_id = str( - get_ip()) + ":" + str(vllm_config.kv_transfer_config. 
- kv_connector_extra_config['handshake_port']) + self.engine_id = ( + str(get_ip()) + + ":" + + str( + vllm_config.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] + ) + ) self.mode = get_moriio_mode() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[MoRIIOConnectorScheduler] = \ + self.connector_scheduler: MoRIIOConnectorScheduler | None = ( MoRIIOConnectorScheduler(vllm_config, self.engine_id) - self.connector_worker: Optional[MoRIIOConnectorWorker] = None + ) + self.connector_worker: MoRIIOConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = MoRIIOConnectorWorker( - vllm_config, self.engine_id) + self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}" - ) + ) ############################################################ # Scheduler Side Methods ############################################################ def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens, self.connector_worker) + request, blocks, num_external_tokens, self.connector_worker + ) def build_connector_meta( self, @@ -885,7 +896,7 @@ class MoRIIOConnector(KVConnectorBase_V1): self, request: "Request", block_ids: 
list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -896,15 +907,12 @@ class MoRIIOConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if self.mode == MoRIIOMode.WRITE: if get_role() == ROLE.CONSUMER: self.connector_worker.moriio_wrapper.async_wait_reqid() @@ -915,27 +923,32 @@ class MoRIIOConnector(KVConnectorBase_V1): def wait_for_layer_load(self, layer_name: str) -> None: pass - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: # Only producer/prefill saves KV Cache if get_role() == ROLE.CONSUMER: - return - self.connector_worker.save_kv_layer(self._connector_metadata, - layer_name, kv_layer, - attn_metadata, **kwargs) + return + self.connector_worker.save_kv_layer( + self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs + ) return None def wait_for_save(self): pass + def has_connector_metadata(self) -> bool: """Check whether the connector metadata is currently set. Returns: bool: True if connector metadata exists, False otherwise. 
""" - try : + try: return self._connector_metadata is not None except AttributeError: return False @@ -949,17 +962,18 @@ class MoRIIOConnectorScheduler: self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - self.mode=get_moriio_mode() + self.mode = get_moriio_mode() - - - self.handeshake_port=self.vllm_config.kv_transfer_config.kv_connector_extra_config['handshake_port'] - logger.info( - f"==========> Initializing MoRIIO Scheduler {engine_id = }" + self.handeshake_port = ( + self.vllm_config.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] ) + logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }") - self.side_notify_port = self.vllm_config.kv_transfer_config.kv_connector_extra_config[ - 'notify_port'] + self.side_notify_port = ( + self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] + ) self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" @@ -968,12 +982,11 @@ class MoRIIOConnectorScheduler: # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} - + # For chunked prefill, we perform layer-wise access within the final chunk. # TODO: Perform access at the end of each chunk. 
self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {} - if self.is_producer: set_role(ROLE.PRODUCER) else: @@ -1006,7 +1019,6 @@ class MoRIIOConnectorScheduler: if self.is_producer: return 0, False - if self.mode == MoRIIOMode.WRITE: # MoriiO in write mode, no remote prefill @@ -1014,50 +1026,46 @@ class MoRIIOConnectorScheduler: return len(request.prompt_token_ids) - 1 - num_computed_tokens, False - def send_notify_block(self, - req_id: str, - block_notify_list: list[int] = None, - host=None, - port=None): - + def send_notify_block( + self, req_id: str, block_notify_list: list[int] = None, host=None, port=None + ): path = make_zmq_path("tcp", host, port) if path not in self.paths: ctx = zmq.Context() - sock = make_zmq_socket(ctx=ctx, - path=path, - socket_type=zmq.DEALER, - bind=False) + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) self.paths[path] = sock data = { "req_id": req_id, "block_notify_list": block_notify_list or [], "decode_rank": self.dp_rank, - "type": "remote_blocks" + "type": "remote_blocks", } # logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }") serialized_data = msgpack.dumps(data) self.paths[path].send(serialized_data) def update_state_after_alloc( - self, - request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int, - connector_worker: Optional["MoRIIOConnectorWorker"] = None): - + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + connector_worker: Optional["MoRIIOConnectorWorker"] = None, + ): params = request.kv_transfer_params if params.get("do_remote_decode"): local_block_ids = blocks.get_block_ids()[0] - self._reqs_need_save[request.request_id] = (request, - local_block_ids) + self._reqs_need_save[request.request_id] = (request, local_block_ids) if params is not None and params.get("do_remote_prefill"): if self.mode == MoRIIOMode.READ: if remote_block_ids := params.get("remote_block_ids"): - if all(p 
in params - for p in ("remote_engine_id", "remote_host", - "remote_port")): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): # If remote_blocks and num_external_tokens = 0, we # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. @@ -1068,29 +1076,33 @@ class MoRIIOConnectorScheduler: if len(local_block_ids) == len(remote_block_ids): pass else: - local_block_ids = remote_block_ids[ - -len(local_block_ids):] + local_block_ids = remote_block_ids[-len(local_block_ids) :] self._reqs_need_recv[request.request_id] = ( - request, local_block_ids) + request, + local_block_ids, + ) else: logger.warning( "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", params) - + "request will not utilize KVTransfer", + params, + ) + else: - remote_dp_rank = request.kv_transfer_params.get('remote_dp_rank', 0) + remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) for tp_index in range(self.tp_size): - target_port = request.kv_transfer_params[ - 'remote_notify_port'] + get_port_offset(remote_dp_rank, tp_index) - + "remote_notify_port" + ] + get_port_offset(remote_dp_rank, tp_index) - self.send_notify_block(req_id=request.request_id, - block_notify_list=blocks.get_block_ids()[0], - host=params.get("remote_host"), - port=target_port) + self.send_notify_block( + req_id=request.request_id, + block_notify_list=blocks.get_block_ids()[0], + host=params.get("remote_host"), + port=target_port, + ) # Only trigger 1 KV transfer per request. 
@@ -1105,32 +1117,44 @@ class MoRIIOConnectorScheduler: if self.mode == MoRIIOMode.WRITE: # when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs - if get_role()== ROLE.CONSUMER: + if get_role() == ROLE.CONSUMER: for new_req in scheduler_output.scheduled_new_reqs: red_id = new_req.req_id local_block_ids = list(new_req.block_ids) kv_transfer_params = new_req.sampling_params.extra_args[ - 'kv_transfer_params'] + "kv_transfer_params" + ] meta.add_new_req( red_id, local_block_ids, kv_transfer_params, ) - if get_role()== ROLE.PRODUCER: - # This is the logic for checking against chunked prefill. + if get_role() == ROLE.PRODUCER: + # This is the logic for checking against chunked prefill. # When the last chunk is identified, it places the request metadata into the saving queue. - - for i,req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids): - new_block_ids = scheduler_output.scheduled_cached_reqs.new_block_ids[i] - - if new_block_ids is not None: + + for i, req_id in enumerate( + scheduler_output.scheduled_cached_reqs.req_ids + ): + new_block_ids = ( + scheduler_output.scheduled_cached_reqs.new_block_ids[i] + ) + + if new_block_ids is not None: block_ids = new_block_ids[0] - + req, existing_blocks = self._reqs_need_pending_save[req_id] - updated_blocks = list(existing_blocks) + ([block_ids] if isinstance(block_ids, int) else block_ids) + updated_blocks = list(existing_blocks) + ( + [block_ids] if isinstance(block_ids, int) else block_ids + ) self._reqs_need_pending_save[req_id] = (req, updated_blocks) - if len(self._reqs_need_pending_save[req_id][1]*self.block_size)>=req.num_prompt_tokens: - + if ( + len( + self._reqs_need_pending_save[req_id][1] + * self.block_size + ) + >= req.num_prompt_tokens + ): meta.add_new_req( request_id=req_id, local_block_ids=self._reqs_need_pending_save[req_id][1], @@ -1138,8 +1162,7 @@ class MoRIIOConnectorScheduler: write_mode=True, ) del self._reqs_need_pending_save[req_id] - - + # Loop through 
scheduled reqs and convert to ReqMeta. for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None @@ -1151,7 +1174,7 @@ class MoRIIOConnectorScheduler: for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - if req.num_prompt_tokens>len(block_ids): + if req.num_prompt_tokens > len(block_ids): # not last chunk prefill self._reqs_need_pending_save[req_id] = (req, block_ids) continue @@ -1175,7 +1198,7 @@ class MoRIIOConnectorScheduler: self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. @@ -1184,7 +1207,10 @@ class MoRIIOConnectorScheduler: params = request.kv_transfer_params logger.debug( "MoriioConnector request_finished, request_status=%s, " - "kv_transfer_params=%s", request.status, params) + "kv_transfer_params=%s", + request.status, + params, + ) if not params: return False, None @@ -1199,8 +1225,10 @@ class MoRIIOConnectorScheduler: params["do_remote_prefill"] = False return False, None - if (not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + if ( + not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED + ): return False, None # computed_block_ids = block_ids if all_full else block_ids[:-1] @@ -1210,9 +1238,10 @@ class MoRIIOConnectorScheduler: if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[request.request_id] = time.perf_counter( - ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT - + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) + # If we execute in P-D serial mode, no notification port is needed. 
return delay_free_blocks, dict( do_remote_prefill=True, @@ -1221,7 +1250,8 @@ class MoRIIOConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.handeshake_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) class MoRIIOConnectorWorker: @@ -1233,9 +1263,9 @@ class MoRIIOConnectorWorker: "MoRIIO is not available. Please ensure the 'mori' package " "is installed and properly configured." ) - + self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config) - self.mode=get_moriio_mode() + self.mode = get_moriio_mode() logger.info("Initializing MoRIIO worker %s", engine_id) # for debug @@ -1254,59 +1284,75 @@ class MoRIIOConnectorWorker: self._rank = get_world_group().rank self._local_rank = get_world_group().local_rank self.tp_rank = self.moriio_config.tp_rank - self.dp_rank= self.moriio_config.dp_rank - - logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" - f",{self.is_producer= }") - - + self.dp_rank = self.moriio_config.dp_rank + + logger.info( + f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" + f",{self.is_producer= }" + ) + self.local_ip = self.moriio_config.local_ip - self.local_kv_port=self.moriio_config.local_kv_port + self.local_kv_port = self.moriio_config.local_kv_port self.proxy_ip = self.moriio_config.proxy_ip self.proxy_port = self.moriio_config.proxy_port self.local_ping_port = self.moriio_config.local_ping_port - self.proxy_ping_port =self.moriio_config.proxy_ping_port + self.proxy_ping_port = self.moriio_config.proxy_ping_port self.http_port = self.moriio_config.http_port self.handshake_port = self.moriio_config.handshake_port self.notify_port = self.moriio_config.notify_port - + self.zmq_context = zmq.Context() - self.metadata_address = f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" - self.request_address = 
f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + self.metadata_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" + ) + self.request_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + ) self.moriio_engine = None self._handle_request_thread = None self._ping_thread = None self._writer = MoRIIOWriter(self) - - engine_suffix = (f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" - f":tp {self.tp_rank}:dp {self.dp_rank}") + + engine_suffix = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" + f":tp {self.tp_rank}:dp {self.dp_rank}" + ) if not self.is_producer: self.poller = zmq.Poller() self.metadata_socket = self.zmq_context.socket(zmq.ROUTER) self.metadata_socket.bind(f"tcp://{self.metadata_address}") self.poller.register(self.metadata_socket, zmq.POLLIN) - self.moriio_engine = IOEngine( "consumer:" + engine_suffix, - IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) - + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), + ) + self._handle_request_thread = threading.Thread( - target=self.handle_proxy_request, daemon=True) + target=self.handle_proxy_request, daemon=True + ) self._handle_request_thread.start() else: - self.moriio_engine = IOEngine( "producer:" + engine_suffix, - IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) - - logger.info("build IOEngine %s:%s", self.moriio_config.local_ip, self.moriio_config.local_kv_port) + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), + ) + + logger.info( + "build IOEngine %s:%s", + self.moriio_config.local_ip, + self.moriio_config.local_kv_port, + ) if self._rank == 0 and self.moriio_config.proxy_ip: - self._ping_thread = threading.Thread(target=self._ping, - args=(self.zmq_context, ), - daemon=True) + self._ping_thread = threading.Thread( + target=self._ping, 
args=(self.zmq_context,), daemon=True + ) self._ping_thread.start() logger.info( @@ -1316,24 +1362,23 @@ class MoRIIOConnectorWorker: f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.proxy_port = },{self.local_ping_port = },{self.proxy_ping_port = }" ) # Agent. - self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank,dp_rank=self.dp_rank) - + self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) + self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.notify_port = self.moriio_config.notify_port - - + self.local_kv_cache_metadata = [] self.local_kv_cache_size = [] - self.layer_name_to_local_kv_cache_metadata: dict[str, - List[Any]] = dict() + self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() self.remote_kv_cache_metadata = [] self.remote_kv_cache_size = [] - self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[ - str, List[Any]]] = dict() + self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( + dict() + ) self.slot_size_bytes = 0 self.load_ready_flag = False @@ -1349,8 +1394,8 @@ class MoRIIOConnectorWorker: self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self.side_channel_port: int = ( - self.moriio_config.handshake_port + - get_port_offset(self.dp_rank,self.tp_rank) + self.moriio_config.handshake_port + + get_port_offset(self.dp_rank, self.tp_rank) ) logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}, han") @@ -1371,25 +1416,24 @@ class MoRIIOConnectorWorker: self.num_regions = 0 self.num_layers = 0 - - # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. self.dst_num_blocks: dict[EngineId, int] = {} # In progress transfers. 
- self._recving_transfers:defaultdict[ReqId, list]=defaultdict(list) - self._recving_transfers_callback_addr: dict[ReqId, tuple[str,str]]= {} - + self._recving_transfers: defaultdict[ReqId, list] = defaultdict(list) + self._recving_transfers_callback_addr: dict[ReqId, tuple[str, str]] = {} + # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Background thread for handling new handshake requests. - self._moriio_handshake_listener_t: Optional[threading.Thread] = None + self._moriio_handshake_listener_t: threading.Thread | None = None # Background thread for initializing new MoRIIO handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # MoRIIO is not guaranteed to be thread-safe, limit 1 worker. max_workers=1, - thread_name_prefix="vllm-moriio-handshake-initiator") + thread_name_prefix="vllm-moriio-handshake-initiator", + ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} # Protects _handshake_futures and _remote_agents. 
@@ -1402,17 +1446,19 @@ class MoRIIOConnectorWorker: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) # List of block window sizes for each layer for local attention - self.block_window_per_layer: list[Optional[int]] = [] + self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) self._write_session_lock = threading.Lock() self.debug_cache = [] - backend = get_attn_backend(self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - use_mla=self.use_mla) + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) self.backend_name = backend.get_name() attn_backend = AttentionBackendEnum[self.backend_name] self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER @@ -1422,7 +1468,6 @@ class MoRIIOConnectorWorker: logger.debug("Detected attention backend %s", self.backend_name) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} - ####write worker### # self._write_task_q: Queue[WriteTask] = Queue() @@ -1431,21 +1476,19 @@ class MoRIIOConnectorWorker: # self._deferred_tasks: list[WriteTask] = [] # ####write worker### - - def schedule_write_blocks( self, request_id: str, dst_engine_id: str, local_block_ids: list[int], - remote_block_ids: Optional[list[int]], + remote_block_ids: list[int] | None, layer_name: str, kv_layer: torch.Tensor, remote_notify_port: int, - remote_ip: str + remote_ip: str, ) -> None: """Schedule a block write operation. 
- + Args: request_id: Unique identifier for the request dst_engine_id: Destination engine ID @@ -1461,50 +1504,53 @@ class MoRIIOConnectorWorker: event = torch.cuda.Event() event.record(stream) - task = WriteTask(request_id=request_id, - dst_engine_id=dst_engine_id, - local_block_ids=local_block_ids, - remote_block_ids_hint=remote_block_ids, - layer_name=layer_name, - event=event, - remote_notify_port=remote_notify_port, - remote_ip=remote_ip) + task = WriteTask( + request_id=request_id, + dst_engine_id=dst_engine_id, + local_block_ids=local_block_ids, + remote_block_ids_hint=remote_block_ids, + layer_name=layer_name, + event=event, + remote_notify_port=remote_notify_port, + remote_ip=remote_ip, + ) self._writer.schedule_write(task) - - def _get_built_session(self, remote_engine_id): if remote_engine_id not in self.built_write_session: cur_remote_engine_sessions = [] - for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items( - ): - - unpcaked_local_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata( - local_meta[0]) - unpcaked_remote_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata( - self.layer_name_to_remote_kv_cache_metadata[ - remote_engine_id][ln][0]) + for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): + unpcaked_local_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) + ) + unpcaked_remote_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata( + self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ + ln + ][0] + ) + ) cur_remote_engine_sessions.append( self.moriio_wrapper.build_session( - unpcaked_local_memory_meta, - unpcaked_remote_memory_meta)) - self.built_write_session[ - remote_engine_id] = cur_remote_engine_sessions + unpcaked_local_memory_meta, unpcaked_remote_memory_meta + ) + ) + self.built_write_session[remote_engine_id] = cur_remote_engine_sessions return self.built_write_session[remote_engine_id] def _ping(self, zmq_context): 
PING_INTERVAL = 5 - MAX_RETRIES =100000 - + MAX_RETRIES = 100000 + http_request_address = f"http://{self.request_address}/v1/completions" role = "P" if self.is_producer else "D" - + retry_count = 0 index = 1 - + with zmq_context.socket(zmq.DEALER) as sock: sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}") - + while True: try: data = { @@ -1514,47 +1560,50 @@ class MoRIIOConnectorWorker: "request_address": http_request_address, "handshake_port": self.handshake_port, "notify_port": self.notify_port, - "dp_size":self.moriio_config.dp_size, - "tp_size":self.moriio_config.tp_size, - "transfer_mode":self.mode.name, + "dp_size": self.moriio_config.dp_size, + "tp_size": self.moriio_config.tp_size, + "transfer_mode": self.mode.name, } sock.send(msgpack.dumps(data)) # logger.debug(f"Successfully sent ping message #{index}") - retry_count = 0 - + retry_count = 0 + except ConnectionRefusedError: logger.info( f"Connection refused: {self.local_ip}:{self.local_ping_port} -> " f"{self.proxy_ip}:{self.proxy_ping_port}" ) retry_count += 1 - + except OSError as e: logger.info(f"OS error when sending ping: {e}") retry_count += 1 - + except Exception as e: logger.info(f"Unexpected error when sending ping: {e}") retry_count += 1 - + finally: if retry_count >= MAX_RETRIES: - logger.error(f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop.") + logger.error( + f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop." + ) break - + time.sleep(PING_INTERVAL) index += 1 def handle_proxy_request(self): if self.is_producer: raise NotImplementedError( - "prefill instance doesn't need to send kv cache in pull mode") + "prefill instance doesn't need to send kv cache in pull mode" + ) while True: socks = dict(self.poller.poll()) logger.debug(f"handle_proxy_request: {socks = }") - - #TODO: inkcherry , check here? + + # TODO: inkcherry , check here? 
if self.metadata_socket not in socks: continue else: @@ -1568,44 +1617,55 @@ class MoRIIOConnectorWorker: @staticmethod def _moriio_handshake_listener( - metadata: MoRIIOAgentMetadata, ready_event: threading.Event, - base_port: int, tp_rank: int,dp_rank:int, - layer_name_to_local_kv_cache_metadata: dict): + metadata: MoRIIOAgentMetadata, + ready_event: threading.Event, + base_port: int, + tp_rank: int, + dp_rank: int, + layer_name_to_local_kv_cache_metadata: dict, + ): """Background thread for getting new MoRIIO handshakes.""" encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug("Size of encoded MoRIIOAgentMetadata: %s bytes", - str(size_in_bytes)) + logger.debug( + "Size of encoded MoRIIOAgentMetadata: %s bytes", str(size_in_bytes) + ) # Listen for new requests for metadata. host = "*" - logger.info(f"======> mori handeshake starting listening on baseport: {base_port}") + logger.info( + f"======> mori handeshake starting listening on baseport: {base_port}" + ) - path = make_zmq_path("tcp", host, base_port ) + path = make_zmq_path("tcp", host, base_port) logger.info(f"======> mori handeshake sstarting listening on path: {path}") with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() while True: identity, msg = sock.recv_multipart() - if msg != MoRIIOConstants.GET_META_MSG and msg != MoRIIOConstants.POP_DONE_RECV: - logger.error( - "Connection listener got unexpected message %s", msg) + if ( + msg != MoRIIOConstants.GET_META_MSG + and msg != MoRIIOConstants.POP_DONE_RECV + ): + logger.error("Connection listener got unexpected message %s", msg) raise HandshakeError("handshake failed, unexpected msg type") elif msg == MoRIIOConstants.GET_META_MSG: sock.send_multipart( - (identity, b"", - encoded_data)) # send local mori io engine meta data + (identity, b"", encoded_data) + ) # send local mori io engine meta data logger.info("MoRIIO handshake listener sent metadata to %s") # now we send tensor 
meta data for each block buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) sock.send_multipart((identity, b"", buf)) elif msg == MoRIIOConstants.POP_DONE_RECV: _, req_id = sock.recv_multipart() - logger.info("MoRIIO handshake listener received done recv for req %s", - req_id.decode()) + logger.info( + "MoRIIO handshake listener received done recv for req %s", + req_id.decode(), + ) else: pass @@ -1615,7 +1675,7 @@ class MoRIIOConnectorWorker: port: int, remote_tp_size: int, expected_engine_id: str, - remote_dp_rank:int=0, + remote_dp_rank: int = 0, ) -> dict[int, str]: """Do a MoRIIO handshake with a remote instance.""" @@ -1625,10 +1685,12 @@ class MoRIIOConnectorWorker: # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - - port_offset = get_port_offset(remote_dp_rank,self.tp_rank) + port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) - logger.info("handeshake Querying metadata on path: %s at remote rank %s", path,) + logger.info( + "handeshake Querying metadata on path: %s at remote rank %s", + path, + ) # Send query for the request. 
with zmq_ctx(zmq.DEALER, path) as sock: @@ -1642,17 +1704,22 @@ class MoRIIOConnectorWorker: decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - logger.info("MoRIIO handshake: get metadata took: %s", - got_metadata_time - start_time) + logger.info( + "MoRIIO handshake: get metadata took: %s", + got_metadata_time - start_time, + ) self.moriio_wrapper.remote_engine_ip = host remote_agent_name = self.moriio_wrapper.register_remote_engine( - metadata.agent_metadata) - remote_agent_name=EngineDesc.unpack(metadata.agent_metadata).key - - logger.info(f"MoRIIO handshake: registered remote agent " - f"{remote_agent_name=} for engine ID " - f"{expected_engine_id=},f{path= }") + metadata.agent_metadata + ) + remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key + + logger.info( + f"MoRIIO handshake: registered remote agent " + f"{remote_agent_name=} for engine ID " + f"{expected_engine_id=},f{path= }" + ) if len(self.local_kv_cache_metadata) > 0: logger.warning( f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" @@ -1668,18 +1735,21 @@ class MoRIIOConnectorWorker: if len(received_frame) != 2 or received_frame[0] != b"": assert 0, f"Unexpected frame! 
{received_frame = }" buf = received_frame[1] - self.layer_name_to_remote_kv_cache_metadata[ - expected_engine_id] = pickle.loads(buf) + self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( + pickle.loads(buf) + ) setup_agent_time = time.perf_counter() - logger.debug("MoRIIO handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + logger.debug( + "MoRIIO handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) return {remote_agent_name} - def _background_moriio_handshake(self, req_id: str, - remote_engine_id: EngineId, - meta: ReqMeta): + def _background_moriio_handshake( + self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta + ): # Do MoRIIO handshake in background and add to _ready_requests when done. fut = None if remote_engine_id is not None: @@ -1689,6 +1759,7 @@ class MoRIIOConnectorWorker: port = int(meta.remote_handshake_port) tp_size = int(meta.tp_size) remote_dp_size = int(meta.remote_dp_size) + # TODO: handle failure state of future in the # callback, we want to fail the request in this case. def request_ready(_f: Future[Any], entry=(req_id, meta)): @@ -1696,19 +1767,19 @@ class MoRIIOConnectorWorker: self._ready_requests.put(entry) self.load_ready_flag = True self.write_ready_flags[remote_engine_id] = True - + fut_list = [] - + # In dp(prefill)<->dp(decode) communication, we require an all-to-all handshake. 
for cur_dp_rank in range(remote_dp_size): dp_engine_id = f"{remote_engine_id}_dp{cur_dp_rank}" - + future = self._handshake_initiation_executor.submit( self._moriio_handshake, host, port, tp_size, dp_engine_id, cur_dp_rank ) fut_list.append(future) - + def done_callback(f: Future[dict[int, str]], eid=dp_engine_id): with self._handshake_lock: self._handshake_futures.pop(eid, None) @@ -1716,23 +1787,22 @@ class MoRIIOConnectorWorker: self._remote_agents[eid] = f.result() except Exception: logger.exception("Handshake with %s failed", eid) - + future.add_done_callback(done_callback) self._handshake_futures[dp_engine_id] = future - + # fut = fut_list def wait_all_dp(): for future in fut_list: - future.result() + future.result() return True all_done_future = self._handshake_initiation_executor.submit(wait_all_dp) all_done_future.add_done_callback(request_ready) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in moriio.""" - + # kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize) _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() @@ -1759,7 +1829,9 @@ class MoRIIOConnectorWorker: block_shape = first_kv_cache.shape[-block_rank:] block_size, n_kv_heads, head_dim = block_shape[-3:] # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim # 1 token 1 layer size , slot size + self.slot_size_bytes = ( + kv_elem_size * n_kv_heads * head_dim + ) # 1 token 1 layer size , slot size assert block_size == self.block_size # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc @@ -1784,35 +1856,32 @@ class MoRIIOConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). 
- - for cache_or_caches in kv_caches.values(): - cache_list = [ - cache_or_caches - ] if use_mla or self._use_flashinfer else cache_or_caches + for cache_or_caches in kv_caches.values(): + cache_list = ( + [cache_or_caches] + if use_mla or self._use_flashinfer + else cache_or_caches + ) # logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }") for cache in cache_list: - base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - caches_data.append( - (base_addr, region_len, cache.device.index, "")) + caches_data.append((base_addr, region_len, cache.device.index, "")) kv_caches_base_addr.append(base_addr) for layer_name, kv_cache in kv_caches.items(): - if layer_name not in self.layer_name_to_local_kv_cache_metadata: self.layer_name_to_local_kv_cache_metadata[layer_name] = [] # for cache in cache_list: # moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache) - moriio_mem_metadata = self.moriio_wrapper.register_local_tensor( - kv_cache) + moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache) self.layer_name_to_local_kv_cache_metadata[layer_name].append( - moriio_mem_metadata) + moriio_mem_metadata + ) - self.local_kv_cache_size.append(cache.nelement() * - cache.element_size()) + self.local_kv_cache_size.append(cache.nelement() * cache.element_size()) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) @@ -1821,8 +1890,10 @@ class MoRIIOConnectorWorker: # Optimization for models with local attention (Llama 4) if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) + + assert isinstance( + self.vllm_config.model_config.hf_text_config, Llama4TextConfig + ) llama4_config = self.vllm_config.model_config.hf_text_config no_rope_layers = llama4_config.no_rope_layers 
chunk_size = llama4_config.attention_chunk_size @@ -1833,8 +1904,10 @@ class MoRIIOConnectorWorker: is_local_attention = no_rope_layers[layer_idx] != 0 block_window = chunk_block_size if is_local_attention else None self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) + logger.debug( + "Llama 4 block window per layer mapping: %s", + self.block_window_per_layer, + ) assert len(self.block_window_per_layer) == self.num_layers metadata = MoRIIOAgentMetadata( @@ -1843,14 +1916,22 @@ class MoRIIOConnectorWorker: kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, block_len=self.block_len, - attn_backend_name=self.backend_name) + attn_backend_name=self.backend_name, + ) ready_event = threading.Event() self._moriio_handshake_listener_t = threading.Thread( target=self._moriio_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank,self.dp_rank, - self.layer_name_to_local_kv_cache_metadata), + args=( + metadata, + ready_event, + self.side_channel_port, + self.tp_rank, + self.dp_rank, + self.layer_name_to_local_kv_cache_metadata, + ), daemon=True, - name="moriio_handshake_listener") + name="moriio_handshake_listener", + ) self._moriio_handshake_listener_t.start() ready_event.wait() # Wait for listener ZMQ socket to be ready. 
self.moriio_wrapper.async_wait_reqid() @@ -1869,49 +1950,54 @@ class MoRIIOConnectorWorker: if self.mode == MoRIIOMode.WRITE: done_recving = set() else: - done_recving=self._pop_done_transfers() + done_recving = self._pop_done_transfers() else: if self.mode == MoRIIOMode.WRITE: self.moriio_wrapper.async_wait_reqid() - done_sending, done_recving = set( - ), self.moriio_wrapper.pop_finished_write_req_ids() + done_sending, done_recving = ( + set(), + self.moriio_wrapper.pop_finished_write_req_ids(), + ) return done_sending, done_recving - def _pop_done_transfers(self) -> set[str]: - done_req_ids: set[str] = set() for req_id, status_list in self._recving_transfers.items(): if status_list[-1].Succeeded(): done_req_ids.add(req_id) - + self.moriio_wrapper.send_notify( - req_id,self._recving_transfers_callback_addr[req_id][0], - self._recving_transfers_callback_addr[req_id][1]) + req_id, + self._recving_transfers_callback_addr[req_id][0], + self._recving_transfers_callback_addr[req_id][1], + ) del self._recving_transfers[req_id] del self._recving_transfers_callback_addr[req_id] - + return done_req_ids - - def save_kv_layer(self, metadata: MoRIIOConnectorMetadata, layer_name: str, - kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs): - + def save_kv_layer( + self, + metadata: MoRIIOConnectorMetadata, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ): if not self.is_producer: return if self.mode == MoRIIOMode.READ: return remote_engine_id = None - - + for req_id, meta in metadata.reqs_to_save.items(): remote_engine_id = meta.remote_engine_id # we only need to check if dp0 in rank - remote_engine_id = str(meta.remote_host) + ":" + str( - meta.remote_handshake_port) - + remote_engine_id = ( + str(meta.remote_host) + ":" + str(meta.remote_handshake_port) + ) + meta.remote_engine_id = remote_engine_id # TODO: mz get_remote_engine_id() for engine_id mapping. 
@@ -1920,10 +2006,10 @@ class MoRIIOConnectorWorker: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: if remote_engine_id not in self._remote_agents: - logger.info( - f"*****background moriio {remote_engine_id = }") + logger.info(f"*****background moriio {remote_engine_id = }") self._background_moriio_handshake( - req_id, remote_engine_id, meta) + req_id, remote_engine_id, meta + ) continue self._write_blocks_for_req(req_id, meta, layer_name, kv_layer) @@ -1931,12 +2017,17 @@ class MoRIIOConnectorWorker: while True: if remote_engine_id is None: break - if self._ready_requests.empty() and remote_engine_id not in self.write_ready_flags: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.write_ready_flags + ): continue - elif not self._ready_requests.empty() and (remote_engine_id - in self.write_ready_flags): - self._write_blocks_for_req(*self._ready_requests.get_nowait(), - layer_name, kv_layer) + elif not self._ready_requests.empty() and ( + remote_engine_id in self.write_ready_flags + ): + self._write_blocks_for_req( + *self._ready_requests.get_nowait(), layer_name, kv_layer + ) break else: break @@ -1957,8 +2048,9 @@ class MoRIIOConnectorWorker: remote_engine_id = None for req_id, meta in metadata.reqs_to_recv.items(): - remote_engine_id = str(meta.remote_host) + ":" + str( - meta.remote_handshake_port) + remote_engine_id = ( + str(meta.remote_host) + ":" + str(meta.remote_handshake_port) + ) meta.remote_engine_id = remote_engine_id dp0_remote_engine_id = f"{remote_engine_id}_dp0" if dp0_remote_engine_id not in self._remote_agents: @@ -1966,7 +2058,8 @@ class MoRIIOConnectorWorker: with self._handshake_lock: if remote_engine_id not in self._remote_agents: self._background_moriio_handshake( - req_id, remote_engine_id, meta) + req_id, remote_engine_id, meta + ) wait_handshage_readd_req = True continue @@ -1975,9 +2068,12 @@ class MoRIIOConnectorWorker: self._read_blocks_for_req(req_id, meta) # Start 
transfers for requests whose handshakes have now finished. - while True: #TODO - if self._ready_requests.empty( - ) and not self.load_ready_flag and wait_handshage_readd_req: + while True: # TODO + if ( + self._ready_requests.empty() + and not self.load_ready_flag + and wait_handshage_readd_req + ): continue elif not self._ready_requests.empty() and self.load_ready_flag: self._read_blocks_for_req(*self._ready_requests.get_nowait()) @@ -1986,12 +2082,13 @@ class MoRIIOConnectorWorker: break self._reqs_to_send.update(metadata.reqs_to_send) - def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, req_id) + meta.remote_engine_id, + req_id, + ) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -2001,18 +2098,19 @@ class MoRIIOConnectorWorker: remote_notify_port=meta.remote_notify_port, ) - def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, - kv_layer): + def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer): # logger.debug(f"write block for req {req_id} to remote engine " # f"{meta.remote_engine_id}") - self.schedule_write_blocks(request_id=req_id, - dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - layer_name=layer_name, - kv_layer=kv_layer, - remote_notify_port=meta.remote_notify_port, - remote_ip=meta.remote_host) + self.schedule_write_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + layer_name=layer_name, + kv_layer=kv_layer, + remote_notify_port=meta.remote_notify_port, + remote_ip=meta.remote_host, + ) def _is_last_layer(self, layer_name): if layer_name == list(self.kv_caches.keys())[-1]: @@ -2025,12 +2123,12 @@ class MoRIIOConnectorWorker: return False def merge_contiguous_blocks( - self, - offsets_local: 
List[int], - offsets_remote: List[int], - sizes: List[int], - assume_sorted: bool = False - ) -> Tuple[List[int], List[int], List[int]]: + self, + offsets_local: list[int], + offsets_remote: list[int], + sizes: list[int], + assume_sorted: bool = False, + ) -> tuple[list[int], list[int], list[int]]: n = len(offsets_local) if n == 0: return [], [], [] @@ -2056,8 +2154,11 @@ class MoRIIOConnectorWorker: sizes_sorted = sizes_arr[sort_idx] if n == 1: - return [int(local_sorted[0])], [int(remote_sorted[0]) - ], [int(sizes_sorted[0])] + return ( + [int(local_sorted[0])], + [int(remote_sorted[0])], + [int(sizes_sorted[0])], + ) diff_local = local_sorted[1:] - local_sorted[:-1] diff_remote = remote_sorted[1:] - remote_sorted[:-1] @@ -2066,13 +2167,11 @@ class MoRIIOConnectorWorker: contiguous = (diff_local == prev_size) & (diff_remote == prev_size) if not contiguous.any(): - return local_sorted.tolist(), remote_sorted.tolist( - ), sizes_sorted.tolist() + return local_sorted.tolist(), remote_sorted.tolist(), sizes_sorted.tolist() if contiguous.all(): total_size = int(sizes_sorted.sum()) - return [int(local_sorted[0])], [int(remote_sorted[0]) - ], [total_size] + return [int(local_sorted[0])], [int(remote_sorted[0])], [total_size] break_positions = np.flatnonzero(~contiguous) + 1 segment_starts = np.concatenate(([0], break_positions)) @@ -2089,8 +2188,9 @@ class MoRIIOConnectorWorker: merged_local[si] = int(local_sorted[s]) merged_remote[si] = int(remote_sorted[s]) - merged_sizes[si] = int(local_sorted[e - 1] + sizes_sorted[e - 1] - - local_sorted[s]) + merged_sizes[si] = int( + local_sorted[e - 1] + sizes_sorted[e - 1] - local_sorted[s] + ) return merged_local, merged_remote, merged_sizes @@ -2099,18 +2199,18 @@ class MoRIIOConnectorWorker: layer_name: str, local_block_ids: list[int], remote_block_ids: list[int], - ) -> tuple[list[int], list[int], list[int]]: + ) -> tuple[list[int], list[int], list[int]]: """Compute transfer offsets for block data. 
- + Args: layer_name: Name of the layer to transfer local_block_ids: IDs of local blocks remote_block_ids: IDs of remote blocks - + Returns: Tuple of (local_offsets, remote_offsets, transfer_sizes) """ - is_mla = (len(self.kv_cache_shape) == 3) + is_mla = len(self.kv_cache_shape) == 3 stride = self.kv_caches[layer_name].stride() sz = self.kv_caches[layer_name].element_size() if is_mla: @@ -2144,32 +2244,46 @@ class MoRIIOConnectorWorker: w += 1 merged_l, merged_r, merged_s = self.merge_contiguous_blocks( - offset_local, offset_remote, sizes, assume_sorted=True) + offset_local, offset_remote, sizes, assume_sorted=True + ) return merged_l, merged_r, merged_s - def _read_blocks(self, local_block_ids: list[int], - remote_block_ids: list[int], dst_engine_id: str, - request_id: str, - remote_host: str, - remote_notify_port: int)-> None: + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + remote_host: str, + remote_notify_port: int, + ) -> None: if self.mode == MoRIIOMode.WRITE: return # we only test TP<->TP in read mode # assert self.dp_rank>0, "only test TP<->TP in read mode" - dst_engine_id+="_dp0" + dst_engine_id += "_dp0" sessions = self._get_built_session(dst_engine_id) - + first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] - offs = self._compute_block_transfer_offsets(first_layer, local_block_ids, remote_block_ids) + offs = self._compute_block_transfer_offsets( + first_layer, local_block_ids, remote_block_ids + ) a, b, c = offs[0], offs[1], offs[2] for layer_name in self.layer_name_to_local_kv_cache_metadata.keys(): - sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(layer_name) - transfer_status=self.moriio_wrapper.read_remote_data(c, a, b, sessions[sess_idx]) - + sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( + layer_name + ) + transfer_status = self.moriio_wrapper.read_remote_data( + c, a, b, sessions[sess_idx] + 
) + self._recving_transfers[request_id].append(transfer_status) - self._recving_transfers_callback_addr[request_id]=(remote_host,remote_notify_port + self.tp_rank) - + self._recving_transfers_callback_addr[request_id] = ( + remote_host, + remote_notify_port + self.tp_rank, + ) + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: @@ -2178,13 +2292,12 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): raise ValueError(f"Unexpected socket type: {socket_type}") - ctx: Optional[zmq.Context] = None + ctx: zmq.Context | None = None try: ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) finally: if ctx is not None: - ctx.destroy(linger=0) \ No newline at end of file + ctx.destroy(linger=0) From ecbad2a70b6160913968eb80a76b2eef6943a120 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 19 Nov 2025 10:01:44 +0000 Subject: [PATCH 06/62] add proxy example Signed-off-by: inkcherry --- .../moriio_integration/toy_proxy_server.py | 242 ++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 tests/v1/kv_connector/moriio_integration/toy_proxy_server.py diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py new file mode 100644 index 0000000000000..ea3127df57824 --- /dev/null +++ b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -0,0 +1,242 @@ +import argparse +import logging +import os +import socket +import uuid +import msgpack +import zmq +import copy +import threading +from quart import Quart, make_response, request +import re +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from typing import Dict,List 
+import asyncio +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +import aiohttp +prefill_instances = [] +decode_instances = [] +request_nums = 0 +app = Quart(__name__) + +yield_chunk = set() +IP_PORT_PATTERN = re.compile(r'//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)') + +from itertools import count + + + +TRANSFER_TYPE=None +def _append_whole_dict_unique(target_list, data_dict): + new_filtered = {k: v for k, v in data_dict.items() if k != "index"} + for existed in target_list: + existed_filtered = {k: v for k, v in existed.items() if k != "index"} + if existed_filtered == new_filtered: + return False + print("!!APPEND!!", data_dict) + target_list.append(data_dict) + transfer_mode = data_dict.get("transfer_mode", "unknown") + global TRANSFER_TYPE + + if TRANSFER_TYPE is None: + TRANSFER_TYPE = transfer_mode + logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) + elif TRANSFER_TYPE != transfer_mode: + raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") + + return True +_list_lock = threading.RLock() + +def _listen_for_register(hostname, port): + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + poller = zmq.Poller() + poller.register(router_socket,zmq.POLLIN) + global prefill_instances + global decode_instances + + while True: + socks = dict(poller.poll()) + if router_socket in socks: + + remote_addr,msg = router_socket.recv_multipart() + data = msgpack.loads(msg) + if data['type'] == "HELLO": + pass + elif data['type'] == "register" and data['role'] == "P": + if data['request_address'] not in prefill_instances: + with _list_lock: + _append_whole_dict_unique(prefill_instances, data) + + elif data["type"] == "register" and data['role'] == "D": + if data['request_address'] not in decode_instances: + with _list_lock: + _append_whole_dict_unique(decode_instances, data) + +def start_service_discovery(hostname, port): + if not hostname: + hostname = 
socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + + _listener_thread = threading.Thread( + target = _listen_for_register,args = (hostname, port),daemon=True + ) + _listener_thread.start() + return _listener_thread + +async def send_request_to_prefill(endpoint,req_data,request_id,p_endpoint,pip,pports,selected_prefill_dp_rank): + req_data_copy =req_data + + + req_data_copy['kv_transfer_params'].update({ + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_handshake_port": p_endpoint['handshake_port'], + "remote_notify_port":p_endpoint['notify_port'], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host":pip , + "remote_port": pports, + }) + req_data_copy["stream"] = False + req_data_copy["max_tokens"] = 1 + if "max_completion_tokens" in req_data_copy: + req_data_copy["max_completion_tokens"] = 1 + if "stream_options" in req_data_copy: + del req_data_copy["stream_options"] + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + if selected_prefill_dp_rank is not None: + headers['X-data-parallel-rank']=str(selected_prefill_dp_rank) + async with session.post(url=endpoint, json=req_data_copy, headers=headers) as response: + if response.status == 200: + return await response.json() + + else: + raise RuntimeError("send_request_to_prefill response.status != 200,response.statuus = ",response.status) +async def start_decode_request(endpoint, req_data, request_id): + session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + response = await session.post(url=endpoint, json=req_data, headers=headers) + return session, response + +async def stream_decode_response(session, response, request_id): + try: + if response.status 
== 200:
+            async for chunk_bytes in response.content.iter_chunked(1024):
+
+                yield chunk_bytes
+        else:
+            raise RuntimeError(f"decode response.status != 200, status = {response.status}")
+    finally:
+        await session.close()
+async def send_request_to_decode(endpoint,req_data,request_id):
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session:
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+            "X-Request-Id": request_id
+        }
+        async with session.post(url=endpoint, json=req_data, headers=headers) as response:
+            if response.status == 200:
+                async for chunk_bytes in response.content.iter_chunked(1024):
+
+                    yield chunk_bytes
+            else:
+                raise RuntimeError(f"send_request_to_decode response.status != 200, response.status = {response.status}")
+def example_round_robin_dp_loader(request_number, dp_size):
+    return request_number % dp_size
+
+@app.route("/v1/completions", methods=["POST"])
+@app.route("/v1/chat/completions", methods=["POST"])
+async def handle_request():
+    try:
+
+        global request_nums
+        request_nums += 1
+        def extract_ip_port_fast(url):
+            return IP_PORT_PATTERN.search(url).groups()
+        req_data = await request.get_json()
+        request_id = str(uuid.uuid4())
+
+        prefill_instance_endpoint=None
+        decode_instance_endpoint=None
+
+        pid=request_nums % len(prefill_instances)
+        did=request_nums % len(decode_instances)
+        prefill_instance_endpoint = prefill_instances[pid]
+        decode_instance_endpoint = decode_instances[did]
+
+
+        selected_prefill_dp_rank=None
+        if prefill_instance_endpoint['dp_size']>1:
+            selected_prefill_dp_rank=example_round_robin_dp_loader(request_nums//len(prefill_instances),prefill_instance_endpoint['dp_size'])
+
+        dip,dport= extract_ip_port_fast(decode_instance_endpoint['request_address'])
+        ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
+
+        req_data_to_prefill = copy.deepcopy(req_data)
+        req_data_to_prefill['kv_transfer_params']={}
+        
req_data['kv_transfer_params']={}
+        req_data_to_prefill['kv_transfer_params']['remote_dp_size']=decode_instance_endpoint['dp_size']
+        req_data_to_prefill['kv_transfer_params']['remote_tp_size']=decode_instance_endpoint['tp_size']
+
+
+
+        send_prefill_task = asyncio.create_task(send_request_to_prefill(prefill_instance_endpoint['request_address'],req_data_to_prefill,request_id,decode_instance_endpoint,dip,dport,selected_prefill_dp_rank))
+        ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
+
+
+        if 'max_tokens' in req_data: req_data['max_tokens'] -= 1
+
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode": False,
+            "do_remote_prefill": True,
+            "remote_handshake_port": prefill_instance_endpoint['handshake_port'],
+            "remote_notify_port":prefill_instance_endpoint['notify_port'],
+            "remote_engine_id":None,
+            "remote_block_ids":None,
+            "remote_host":ip ,
+            "remote_port": port,
+        }
+        if TRANSFER_TYPE =="READ":
+            #In read mode, prefill and decode are executed serially.
+            prefill_response=await send_prefill_task
+            req_data['kv_transfer_params']['remote_engine_id']=prefill_response['kv_transfer_params']['remote_engine_id']
+            req_data['kv_transfer_params']['remote_block_ids']=prefill_response['kv_transfer_params']['remote_block_ids']
+
+            req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size']
+            req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size']
+
+
+        decode_request_task = asyncio.create_task(
+            start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id)
+        )
+
+
+        session, decode_response = await decode_request_task
+        stream_generator = stream_decode_response(session, decode_response, request_id)
+        response = await make_response(stream_generator)
+        return response
+    except Exception as e:
+        logger.exception("handle_request failed")
+        return {"error": str(e)}, 500
+
+if __name__ == '__main__':
+    t = start_service_discovery("0.0.0.0", 36367)
+    app.debug = True
+    app.config['BODY_TIMEOUT'] = 360000
+    app.config['RESPONSE_TIMEOUT'] = 360000
+ + app.run(host="0.0.0.0", port=10001) + t.join() \ No newline at end of file From f8e9adfea85e6dd623aa94775a63064f1347b6d7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 03:39:42 +0000 Subject: [PATCH 07/62] refine Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 2856733484cab..697ae3a786271 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -47,10 +47,10 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -import logging from dataclasses import field from enum import Enum from queue import Empty, Queue +import logging logger = init_logger(__name__) @@ -964,7 +964,7 @@ class MoRIIOConnectorScheduler: self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.mode = get_moriio_mode() - self.handeshake_port = ( + self.handshake_port = ( self.vllm_config.kv_transfer_config.kv_connector_extra_config[ "handshake_port" ] @@ -1174,7 +1174,7 @@ class MoRIIOConnectorScheduler: for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - if req.num_prompt_tokens > len(block_ids): + if req.num_prompt_tokens > len(block_ids)*self.block_size: # not last chunk prefill self._reqs_need_pending_save[req_id] = (req, block_ids) continue @@ -1249,7 +1249,7 @@ class MoRIIOConnectorScheduler: remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, remote_host=self.side_channel_host, - remote_port=self.handeshake_port, + remote_port=self.handshake_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, ) @@ -1398,7 +1398,7 @@ class MoRIIOConnectorWorker: + get_port_offset(self.dp_rank, 
self.tp_rank) ) logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") - logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}, han") + logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}") self.engine_id: EngineId = engine_id self.world_size = get_tensor_model_parallel_world_size() @@ -1606,8 +1606,7 @@ class MoRIIOConnectorWorker: # TODO: inkcherry , check here? if self.metadata_socket not in socks: continue - else: - pass + def __del__(self): """Cleanup background threads on destruction.""" @@ -1636,11 +1635,11 @@ class MoRIIOConnectorWorker: # Listen for new requests for metadata. host = "*" logger.info( - f"======> mori handeshake starting listening on baseport: {base_port}" + f"======> mori handshake starting listening on baseport: {base_port}" ) path = make_zmq_path("tcp", host, base_port) - logger.info(f"======> mori handeshake sstarting listening on path: {path}") + logger.info(f"======> mori handshake starting listening on path: {path}") with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -1688,7 +1687,7 @@ class MoRIIOConnectorWorker: port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) logger.info( - "handeshake Querying metadata on path: %s at remote rank %s", + "handshake Querying metadata on path: %s at remote rank %s", path, ) From 68a23333399b02c8f2d3eb0e12afbfdcc059ed13 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 03:55:48 +0000 Subject: [PATCH 08/62] fix dp proxy Signed-off-by: inkcherry --- tests/v1/kv_connector/moriio_integration/toy_proxy_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py index ea3127df57824..e375b84bd4053 100644 --- a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py +++ 
b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -218,7 +218,9 @@ async def handle_request(): req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size'] req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size'] - + if selected_prefill_dp_rank is not None: + req_data['kv_transfer_params']['remote_dp_rank'] = selected_prefill_dp_rank + decode_request_task = asyncio.create_task( start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id) ) From 70ea1b2460d67767d83f6a7f194e231c932f1fcf Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 04:36:19 +0000 Subject: [PATCH 09/62] refine code Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 80 +++++++------------ 1 file changed, 28 insertions(+), 52 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 697ae3a786271..603efc0507bed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -7,6 +7,13 @@ import pickle import queue import threading import time +import msgpack +import msgspec +import numpy as np +import torch +import zmq +import logging + from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor @@ -14,12 +21,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from weakref import ref as weakref_ref -import msgpack -import msgspec -import numpy as np -import torch -import zmq - from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend @@ -50,7 +51,6 @@ if TYPE_CHECKING: from dataclasses import field from enum import Enum from queue import Empty, Queue -import logging logger = 
init_logger(__name__) @@ -70,7 +70,9 @@ class MoRIIOConstants: # Default GPU count per node for standard configurations RANK_PER_NODE = 8 - + + PING_INTERVAL = 5 + MAX_PING_RETRIES = 100000 try: import mori @@ -80,7 +82,6 @@ try: IOEngine, IOEngineConfig, MemoryDesc, - StatusCode, ) logger.info("MoRIIO is available") @@ -404,10 +405,12 @@ class MoRIIOWriter: return # Wait for CUDA event + #The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. + #This event is used to synchronize the kv transfer and computation tasks. task.event.synchronize() # Update engine ID with DP rank - task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" + task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank) # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) @@ -665,7 +668,6 @@ class MoRIIOWrapper: except Exception as e: logger.error(f"Error processing message: {e}") raise HandshakeError(f"Error processing message: {e}") from e - continue self.notify_thread = threading.Thread( target=_async_wait, daemon=True, name="moriio-notify-listener" @@ -685,7 +687,6 @@ class MoRIIOWrapper: data = msgpack.loads(msg) if isinstance(data, dict) and "req_id" in data: self._handle_structured_message(data) - handled = True return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): @@ -841,6 +842,7 @@ class MoRIIOConnector(KVConnectorBase_V1): role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None, ): + super().__init__(vllm_config, role) assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id = ( @@ -1149,10 +1151,8 @@ class MoRIIOConnectorScheduler: ) self._reqs_need_pending_save[req_id] = (req, updated_blocks) if ( - len( - self._reqs_need_pending_save[req_id][1] + len(self._reqs_need_pending_save[req_id][1]) * 
self.block_size
-                )
                     >= req.num_prompt_tokens
                 ):
                     meta.add_new_req(
@@ -1363,13 +1363,9 @@ class MoRIIOConnectorWorker:
         )
 
         # Agent.
         self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
-
         self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
         self.moriio_wrapper.set_backend_type(BackendType.RDMA)
-
         self.moriio_wrapper.notify_port = self.moriio_config.notify_port
-
         self.local_kv_cache_metadata = []
         self.local_kv_cache_size = []
         self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict()
@@ -1450,7 +1446,6 @@ class MoRIIOConnectorWorker:
         self.use_mla = self.model_config.use_mla
         self.built_session = False
         self.built_write_session: defaultdict[str, list] = defaultdict(list)
-        self._write_session_lock = threading.Lock()
         self.debug_cache = []
         backend = get_attn_backend(
             self.model_config.get_head_size(),
@@ -1467,15 +1462,6 @@ class MoRIIOConnectorWorker:
         # self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         logger.debug("Detected attention backend %s", self.backend_name)
 
-        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-
-        ####write worker###
-        # self._write_task_q: Queue[WriteTask] = Queue()
-        # self._write_worker_started = False
-        # self._write_worker_lock = threading.Lock()
-        # self._deferred_tasks: list[WriteTask] = []
-        # ####write worker###
-
     def schedule_write_blocks(
         self,
         request_id: str,
@@ -1539,8 +1525,6 @@ class MoRIIOConnectorWorker:
         return self.built_write_session[remote_engine_id]
 
     def _ping(self, zmq_context):
-        PING_INTERVAL = 5
-        MAX_RETRIES = 100000
 
         http_request_address = f"http://{self.request_address}/v1/completions"
         role = "P" if self.is_producer else "D"
@@ -1585,13 +1569,13 @@ class MoRIIOConnectorWorker:
                     retry_count += 1
 
             finally:
-                if retry_count >= MAX_RETRIES:
+                if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
                     logger.error(
-                        f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop."
+                        f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
) break - time.sleep(PING_INTERVAL) + time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 def handle_proxy_request(self): @@ -1603,7 +1587,6 @@ class MoRIIOConnectorWorker: socks = dict(self.poller.poll()) logger.debug(f"handle_proxy_request: {socks = }") - # TODO: inkcherry , check here? if self.metadata_socket not in socks: continue @@ -1665,8 +1648,6 @@ class MoRIIOConnectorWorker: "MoRIIO handshake listener received done recv for req %s", req_id.decode(), ) - else: - pass def _moriio_handshake( self, @@ -1709,9 +1690,7 @@ class MoRIIOConnectorWorker: ) self.moriio_wrapper.remote_engine_ip = host - remote_agent_name = self.moriio_wrapper.register_remote_engine( - metadata.agent_metadata - ) + remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key logger.info( @@ -1873,8 +1852,6 @@ class MoRIIOConnectorWorker: if layer_name not in self.layer_name_to_local_kv_cache_metadata: self.layer_name_to_local_kv_cache_metadata[layer_name] = [] - # for cache in cache_list: - # moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache) moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache) self.layer_name_to_local_kv_cache_metadata[layer_name].append( moriio_mem_metadata @@ -1942,7 +1919,7 @@ class MoRIIOConnectorWorker: to track which workers are done. 
""" - done_sending, done_recving = set(), set() + done_sending = set() if self.is_producer: done_sending = self.moriio_wrapper.pop_finished_req_ids() @@ -1991,7 +1968,6 @@ class MoRIIOConnectorWorker: remote_engine_id = None for req_id, meta in metadata.reqs_to_save.items(): - remote_engine_id = meta.remote_engine_id # we only need to check if dp0 in rank remote_engine_id = ( str(meta.remote_host) + ":" + str(meta.remote_handshake_port) @@ -2030,13 +2006,15 @@ class MoRIIOConnectorWorker: break else: break - + + + def get_engine_name_with_dp(self, engine_name,dp_rank): + return f"{engine_name}_dp{dp_rank}" def start_load_kv(self, metadata: MoRIIOConnectorMetadata): """ Start loading by triggering non-blocking moriio_xfer. We check for these trnxs to complete in each step(). """ - # print("start load kv") if self.is_producer: self.moriio_wrapper.async_wait_reqid() return @@ -2051,7 +2029,7 @@ class MoRIIOConnectorWorker: str(meta.remote_host) + ":" + str(meta.remote_handshake_port) ) meta.remote_engine_id = remote_engine_id - dp0_remote_engine_id = f"{remote_engine_id}_dp0" + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. 
with self._handshake_lock: @@ -2216,7 +2194,6 @@ class MoRIIOConnectorWorker: blknum, blksize, hs = self.kv_cache_shape hn = 1 block_stride = stride[0] - ktov_stride = None else: _, blknum, blksize, hn, hs = self.kv_cache_shape ktov_stride = stride[0] @@ -2260,21 +2237,20 @@ class MoRIIOConnectorWorker: return # we only test TP<->TP in read mode # assert self.dp_rank>0, "only test TP<->TP in read mode" - dst_engine_id += "_dp0" - sessions = self._get_built_session(dst_engine_id) + dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0) + sessions = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] offs = self._compute_block_transfer_offsets( first_layer, local_block_ids, remote_block_ids ) - a, b, c = offs[0], offs[1], offs[2] for layer_name in self.layer_name_to_local_kv_cache_metadata.keys(): sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) transfer_status = self.moriio_wrapper.read_remote_data( - c, a, b, sessions[sess_idx] + offs[0], offs[1], offs[2], sessions[sess_idx] ) self._recving_transfers[request_id].append(transfer_status) From 64694c3e7693c6f09b63e80f22f5834bd5afeda7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 04:50:58 +0000 Subject: [PATCH 10/62] refine Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 603efc0507bed..c04d6055fd040 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1690,8 +1690,9 @@ class MoRIIOConnectorWorker: ) self.moriio_wrapper.remote_engine_ip = host - - remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key + 
remote_agent_name=self.moriio_wrapper.register_remote_engine( + metadata.agent_metadata + ) logger.info( f"MoRIIO handshake: registered remote agent " @@ -1751,8 +1752,7 @@ class MoRIIOConnectorWorker: # In dp(prefill)<->dp(decode) communication, we require an all-to-all handshake. for cur_dp_rank in range(remote_dp_size): - dp_engine_id = f"{remote_engine_id}_dp{cur_dp_rank}" - + dp_engine_id = self.get_engine_name_with_dp(remote_engine_id, cur_dp_rank) future = self._handshake_initiation_executor.submit( self._moriio_handshake, host, port, tp_size, dp_engine_id, cur_dp_rank ) @@ -1990,8 +1990,6 @@ class MoRIIOConnectorWorker: self._write_blocks_for_req(req_id, meta, layer_name, kv_layer) while True: - if remote_engine_id is None: - break if ( self._ready_requests.empty() and remote_engine_id not in self.write_ready_flags @@ -2021,7 +2019,7 @@ class MoRIIOConnectorWorker: if self.mode == MoRIIOMode.WRITE: return - wait_handshage_readd_req = False + wait_handshake_readd_req = False remote_engine_id = None for req_id, meta in metadata.reqs_to_recv.items(): @@ -2037,7 +2035,7 @@ class MoRIIOConnectorWorker: self._background_moriio_handshake( req_id, remote_engine_id, meta ) - wait_handshage_readd_req = True + wait_handshake_readd_req = True continue @@ -2049,7 +2047,7 @@ class MoRIIOConnectorWorker: if ( self._ready_requests.empty() and not self.load_ready_flag - and wait_handshage_readd_req + and wait_handshake_readd_req ): continue elif not self._ready_requests.empty() and self.load_ready_flag: From 245b71a8916df8e6537a131a4120562be90bd873 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 07:17:34 +0000 Subject: [PATCH 11/62] refine Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index c04d6055fd040..0b37f412fbfe4 
100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -68,8 +68,6 @@ class MoRIIOConstants: OVER = b"OVER" COMPLETION_PREFIX = "cmpl" - # Default GPU count per node for standard configurations - RANK_PER_NODE = 8 PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 @@ -204,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode: def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: - return ((dp_rank) * tp_size + tp_rank) % MoRIIOConstants.RANK_PER_NODE + return ((dp_rank) * tp_size + tp_rank) @dataclass @@ -405,8 +403,8 @@ class MoRIIOWriter: return # Wait for CUDA event - #The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. - #This event is used to synchronize the kv transfer and computation tasks. + # The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. + # This event is used to synchronize the kv transfer and computation tasks. 
task.event.synchronize() # Update engine ID with DP rank @@ -546,8 +544,7 @@ class MoRIIOWrapper: self.done_write_cache_req_ids = [] self.notify_thread = None self.sock = None - self.sessions = [] - self.kv_caches = None + self.sessions: list["IOEngine.Session"] = [] self.paths = {} def set_moriio_engine(self, moriio_engine): @@ -730,7 +727,7 @@ class MoRIIOWrapper: path = make_zmq_path("tcp", remote_ip, str(remote_port)) if path not in self.paths: - ctx = zmq.Context() + ctx = zmq.Context.instance() sock = make_zmq_socket( ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False ) @@ -1033,7 +1030,7 @@ class MoRIIOConnectorScheduler: ): path = make_zmq_path("tcp", host, port) if path not in self.paths: - ctx = zmq.Context() + ctx = zmq.Context.instance() sock = make_zmq_socket( ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False ) @@ -1268,7 +1265,7 @@ class MoRIIOConnectorWorker: self.mode = get_moriio_mode() logger.info("Initializing MoRIIO worker %s", engine_id) - # for debug + logging.getLogger("aiter").disabled = True # Config. @@ -1377,7 +1374,7 @@ class MoRIIOConnectorWorker: ) self.slot_size_bytes = 0 - self.load_ready_flag = False + self.load_ready_flag = {} self.write_ready_flags = {} self.kv_cache_shape = None self.block_shape = None @@ -1569,9 +1566,9 @@ class MoRIIOConnectorWorker: retry_count += 1 finally: - if retry_count >= MoRIIOConstants.MAX_RETRIES: + if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: logger.error( - f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop." + f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop." 
) break @@ -1590,12 +1587,19 @@ class MoRIIOConnectorWorker: if self.metadata_socket not in socks: continue + def close(self): + if hasattr(self, '_handshake_initiation_executor'): + self._handshake_initiation_executor.shutdown(wait=False) + + if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t: + self._moriio_handshake_listener_t.join(timeout=0) + + if hasattr(self, 'zmq_context') and self.zmq_context: + self.zmq_context.destroy(linger=0) + self.zmq_context = None def __del__(self): - """Cleanup background threads on destruction.""" - self._handshake_initiation_executor.shutdown(wait=False) - if self._moriio_handshake_listener_t: - self._moriio_handshake_listener_t.join(timeout=0) + self.close() @staticmethod def _moriio_handshake_listener( @@ -1744,7 +1748,7 @@ class MoRIIOConnectorWorker: def request_ready(_f: Future[Any], entry=(req_id, meta)): logger.info("MoRIIO handshake done for request %s", req_id) self._ready_requests.put(entry) - self.load_ready_flag = True + self.load_ready_flag [remote_engine_id] = True self.write_ready_flags[remote_engine_id] = True fut_list = [] @@ -1826,15 +1830,6 @@ class MoRIIOConnectorWorker: kv_caches_base_addr = [] caches_data = [] - # Note(tms): I modified this from the original region setup code. - # K and V are now in different regions. Advantage is that we can - # elegantly support MLA and any cases where the K and V tensors - # are non-contiguous (it's not locally guaranteed that they will be) - # Disadvantage is that the encoded MoRIIOAgentMetadata is now larger - # (roughly 8KB vs 5KB). - # Conversely for FlashInfer, K and V are transferred in the same tensor - # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in kv_caches.values(): cache_list = ( [cache_or_caches] @@ -2043,18 +2038,20 @@ class MoRIIOConnectorWorker: self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. 
- while True: # TODO + while True: if ( self._ready_requests.empty() - and not self.load_ready_flag + and remote_engine_id not in self.load_ready_flag and wait_handshake_readd_req ): continue - elif not self._ready_requests.empty() and self.load_ready_flag: + elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag: self._read_blocks_for_req(*self._ready_requests.get_nowait()) break else: break + + self._reqs_to_send.update(metadata.reqs_to_send) @@ -2233,8 +2230,8 @@ class MoRIIOConnectorWorker: ) -> None: if self.mode == MoRIIOMode.WRITE: return - # we only test TP<->TP in read mode - # assert self.dp_rank>0, "only test TP<->TP in read mode" + + dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0) sessions = self._get_built_session(dp0_engine_id) @@ -2248,7 +2245,7 @@ class MoRIIOConnectorWorker: layer_name ) transfer_status = self.moriio_wrapper.read_remote_data( - offs[0], offs[1], offs[2], sessions[sess_idx] + offs[2], offs[0], offs[1], sessions[sess_idx] ) self._recving_transfers[request_id].append(transfer_status) From 4f592ae696dbf5553f7d5d2e79807cc4ac740857 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 07:22:02 +0000 Subject: [PATCH 12/62] format Signed-off-by: inkcherry --- .../moriio_integration/toy_proxy_server.py | 267 +++++++++++------- 1 file changed, 161 insertions(+), 106 deletions(-) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py index e375b84bd4053..3091d98366c0a 100644 --- a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -1,35 +1,35 @@ -import argparse +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import copy import logging import os +import re import socket +import threading import uuid + import msgpack import zmq -import copy -import threading 
from quart import Quart, make_response, request -import re -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse -from typing import Dict,List -import asyncio + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) import aiohttp + prefill_instances = [] -decode_instances = [] +decode_instances = [] request_nums = 0 app = Quart(__name__) yield_chunk = set() -IP_PORT_PATTERN = re.compile(r'//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)') - -from itertools import count +IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") +TRANSFER_TYPE = None + -TRANSFER_TYPE=None def _append_whole_dict_unique(target_list, data_dict): new_filtered = {k: v for k, v in data_dict.items() if k != "index"} for existed in target_list: @@ -44,39 +44,42 @@ def _append_whole_dict_unique(target_list, data_dict): if TRANSFER_TYPE is None: TRANSFER_TYPE = transfer_mode logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) - elif TRANSFER_TYPE != transfer_mode: + elif transfer_mode != TRANSFER_TYPE: raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") - + return True + + _list_lock = threading.RLock() + def _listen_for_register(hostname, port): context = zmq.Context() router_socket = context.socket(zmq.ROUTER) router_socket.bind(f"tcp://{hostname}:{port}") poller = zmq.Poller() - poller.register(router_socket,zmq.POLLIN) + poller.register(router_socket, zmq.POLLIN) global prefill_instances global decode_instances while True: socks = dict(poller.poll()) if router_socket in socks: - - remote_addr,msg = router_socket.recv_multipart() + remote_addr, msg = router_socket.recv_multipart() data = msgpack.loads(msg) - if data['type'] == "HELLO": + if data["type"] == "HELLO": pass - elif data['type'] == "register" and data['role'] == "P": - if data['request_address'] not in prefill_instances: + elif data["type"] == "register" and data["role"] == "P": + if data["request_address"] not in prefill_instances: with 
_list_lock: _append_whole_dict_unique(prefill_instances, data) - elif data["type"] == "register" and data['role'] == "D": - if data['request_address'] not in decode_instances: + elif data["type"] == "register" and data["role"] == "D": + if data["request_address"] not in decode_instances: with _list_lock: _append_whole_dict_unique(decode_instances, data) + def start_service_discovery(hostname, port): if not hostname: hostname = socket.gethostname() @@ -84,147 +87,198 @@ def start_service_discovery(hostname, port): raise ValueError("Port cannot be 0") _listener_thread = threading.Thread( - target = _listen_for_register,args = (hostname, port),daemon=True + target=_listen_for_register, args=(hostname, port), daemon=True ) _listener_thread.start() return _listener_thread -async def send_request_to_prefill(endpoint,req_data,request_id,p_endpoint,pip,pports,selected_prefill_dp_rank): - req_data_copy =req_data - - - req_data_copy['kv_transfer_params'].update({ - "do_remote_decode": True, - "do_remote_prefill": False, - "remote_handshake_port": p_endpoint['handshake_port'], - "remote_notify_port":p_endpoint['notify_port'], - "remote_engine_id": None, - "remote_block_ids": None, - "remote_host":pip , - "remote_port": pports, - }) + +async def send_request_to_prefill( + endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank +): + req_data_copy = req_data + + req_data_copy["kv_transfer_params"].update( + { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_handshake_port": p_endpoint["handshake_port"], + "remote_notify_port": p_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": pip, + "remote_port": pports, + } + ) req_data_copy["stream"] = False req_data_copy["max_tokens"] = 1 if "max_completion_tokens" in req_data_copy: req_data_copy["max_completion_tokens"] = 1 if "stream_options" in req_data_copy: del req_data_copy["stream_options"] - async with 
aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } if selected_prefill_dp_rank is not None: - headers['X-data-parallel-rank']=str(selected_prefill_dp_rank) - async with session.post(url=endpoint, json=req_data_copy, headers=headers) as response: + headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank) + async with session.post( + url=endpoint, json=req_data_copy, headers=headers + ) as response: if response.status == 200: return await response.json() - + else: - raise RuntimeError("send_request_to_prefill response.status != 200,response.statuus = ",response.status) + raise RuntimeError( + "send_request_to_prefill response.status != 200,response.statuus = ", + response.status, + ) + + async def start_decode_request(endpoint, req_data, request_id): - session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } response = await session.post(url=endpoint, json=req_data, headers=headers) return session, response + async def stream_decode_response(session, response, request_id): try: if response.status == 200: async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes else: - raise RuntimeError(f"decode response.status != 200, status = {response.status}") + raise RuntimeError( + f"decode response.status != 200, status = {response.status}" + ) finally: await session.close() -async def send_request_to_decode(endpoint,req_data,request_id): - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 
* 6000)) as session: + + +async def send_request_to_decode(endpoint, req_data, request_id): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - async with session.post(url=endpoint, json=req_data, headers=headers) as response: + async with session.post( + url=endpoint, json=req_data, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content.iter_chunked(1024): - - yield chunk_bytes + yield chunk_bytes else: - raise RuntimeError("send_request_to_decode response.status != 200,response.statuus = ",response.status) + raise RuntimeError( + "send_request_to_decode response.status != 200,response.statuus = ", + response.status, + ) + + def example_round_robin_dp_loader(request_number, dp_size): return request_nums % dp_size + @app.route("/v1/completions", methods=["POST"]) @app.route("/v1/chat/completions", methods=["POST"]) async def handle_request(): try: - global request_nums request_nums += 1 + def extract_ip_port_fast(url): return IP_PORT_PATTERN.search(url).groups() + req_data = await request.get_json() request_id = str(uuid.uuid4()) - prefill_instance_endpoint=None - decode_instance_endpoint=None - - pid=request_nums % len(prefill_instances) - did=request_nums % len(decode_instances) + prefill_instance_endpoint = None + decode_instance_endpoint = None + + pid = request_nums % len(prefill_instances) + did = request_nums % len(decode_instances) prefill_instance_endpoint = prefill_instances[pid] decode_instance_endpoint = decode_instances[did] - - - selected_prefill_dp_rank=None - if prefill_instance_endpoint['dp_size']>1: - selected_prefill_dp_rank=example_round_robin_dp_loader(request_nums//len(prefill_instance_endpoint),prefill_instance_endpoint['dp_size']) - - dip,dport= 
extract_ip_port_fast(decode_instance_endpoint['request_address']) - ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address']) - + + selected_prefill_dp_rank = None + if prefill_instance_endpoint["dp_size"] > 1: + selected_prefill_dp_rank = example_round_robin_dp_loader( + request_nums // len(prefill_instance_endpoint), + prefill_instance_endpoint["dp_size"], + ) + + dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"]) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + req_data_to_prefill = copy.deepcopy(req_data) - req_data_to_prefill['kv_transfer_params']={} - req_data['kv_transfer_params']={} - req_data_to_prefill['kv_transfer_params']['remote_dp_size']=decode_instance_endpoint['dp_size'] - req_data_to_prefill['kv_transfer_params']['remote_tp_size']=decode_instance_endpoint['tp_size'] - - - - send_prefill_task = asyncio.create_task(send_request_to_prefill(prefill_instance_endpoint['request_address'],req_data_to_prefill,request_id,decode_instance_endpoint,dip,dport,selected_prefill_dp_rank)) - ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address']) - - - req_data['max_tokens'] -= 1 - - req_data['kv_transfer_params'] = { + req_data_to_prefill["kv_transfer_params"] = {} + req_data["kv_transfer_params"] = {} + req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = ( + decode_instance_endpoint["dp_size"] + ) + req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = ( + decode_instance_endpoint["tp_size"] + ) + + send_prefill_task = asyncio.create_task( + send_request_to_prefill( + prefill_instance_endpoint["request_address"], + req_data_to_prefill, + request_id, + decode_instance_endpoint, + dip, + dport, + selected_prefill_dp_rank, + ) + ) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + + req_data["max_tokens"] -= 1 + + req_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, - 
"remote_handshake_port": prefill_instance_endpoint['handshake_port'], - "remote_notify_port":prefill_instance_endpoint['notify_port'], - "remote_engine_id":None, - "remote_block_ids":None, - "remote_host":ip , + "remote_handshake_port": prefill_instance_endpoint["handshake_port"], + "remote_notify_port": prefill_instance_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": ip, "remote_port": port, } - if TRANSFER_TYPE =="READ": - #In read mode, prefill and decode are executed serially. - prefill_response=await send_prefill_task - req_data['kv_transfer_params']['remote_engine_id']=prefill_response['kv_transfer_params']['remote_engine_id'] - req_data['kv_transfer_params']['remote_block_ids']=prefill_response['kv_transfer_params']['remote_block_ids'] - - req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size'] - req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size'] - + if TRANSFER_TYPE == "READ": + # In read mode, prefill and decode are executed serially. 
+ prefill_response = await send_prefill_task + req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[ + "kv_transfer_params" + ]["remote_engine_id"] + req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[ + "kv_transfer_params" + ]["remote_block_ids"] + + req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[ + "dp_size" + ] + req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[ + "tp_size" + ] + if selected_prefill_dp_rank is not None: - req_data['kv_transfer_params']['remote_dp_rank'] = selected_prefill_dp_rank + req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank decode_request_task = asyncio.create_task( - start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id) + start_decode_request( + decode_instance_endpoint["request_address"], req_data, request_id + ) ) - session, decode_response = await decode_request_task stream_generator = stream_decode_response(session, decode_response, request_id) @@ -234,11 +288,12 @@ async def handle_request(): print(e) pass -if __name__ == '__main__': + +if __name__ == "__main__": t = start_service_discovery("0.0.0.0", 36367) - app.debug = True - app.config['BODY_TIMEOUT'] = 360000 - app.config['RESPONSE_TIMEOUT'] = 360000 + app.debug = True + app.config["BODY_TIMEOUT"] = 360000 + app.config["RESPONSE_TIMEOUT"] = 360000 app.run(host="0.0.0.0", port=10001) - t.join() \ No newline at end of file + t.join() From b60ee86585d34ef97c66dd6864b036d963f22fc6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 07:31:29 +0000 Subject: [PATCH 13/62] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 109 ++++++++---------- 1 file changed, 47 insertions(+), 62 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 0b37f412fbfe4..cc9cc28a9980b 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1,19 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import logging import math import os import pickle import queue import threading import time -import msgpack -import msgspec -import numpy as np -import torch -import zmq -import logging - from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor @@ -21,6 +15,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from weakref import ref as weakref_ref +import msgpack +import msgspec +import numpy as np +import torch +import zmq + from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend @@ -68,10 +68,10 @@ class MoRIIOConstants: OVER = b"OVER" COMPLETION_PREFIX = "cmpl" - PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 + try: import mori from mori.io import ( @@ -202,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode: def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: - return ((dp_rank) * tp_size + tp_rank) + return (dp_rank) * tp_size + tp_rank @dataclass @@ -316,7 +316,6 @@ class MoRIIOWriter: def _write_worker_loop(self) -> None: """Main loop for the write worker thread.""" - logger.info("Write worker loop started") while True: # Process deferred tasks first @@ -408,7 +407,9 @@ class MoRIIOWriter: task.event.synchronize() # Update engine ID with DP rank - task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank) + task.dst_engine_id = self.worker.get_engine_name_with_dp( + task.dst_engine_id, request_info.decode_dp_rank + ) # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) @@ -544,7 +545,7 @@ class 
MoRIIOWrapper: self.done_write_cache_req_ids = [] self.notify_thread = None self.sock = None - self.sessions: list["IOEngine.Session"] = [] + self.sessions: list[IOEngine.Session] = [] self.paths = {} def set_moriio_engine(self, moriio_engine): @@ -968,7 +969,7 @@ class MoRIIOConnectorScheduler: "handshake_port" ] ) - logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }") + logger.info(f"Initializing MoRIIO Scheduler {engine_id = }") self.side_notify_port = ( self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] @@ -1149,7 +1150,7 @@ class MoRIIOConnectorScheduler: self._reqs_need_pending_save[req_id] = (req, updated_blocks) if ( len(self._reqs_need_pending_save[req_id][1]) - * self.block_size + * self.block_size >= req.num_prompt_tokens ): meta.add_new_req( @@ -1171,7 +1172,7 @@ class MoRIIOConnectorScheduler: for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - if req.num_prompt_tokens > len(block_ids)*self.block_size: + if req.num_prompt_tokens > len(block_ids) * self.block_size: # not last chunk prefill self._reqs_need_pending_save[req_id] = (req, block_ids) continue @@ -1265,7 +1266,7 @@ class MoRIIOConnectorWorker: self.mode = get_moriio_mode() logger.info("Initializing MoRIIO worker %s", engine_id) - + logging.getLogger("aiter").disabled = True # Config. 
@@ -1283,11 +1284,6 @@ class MoRIIOConnectorWorker: self.tp_rank = self.moriio_config.tp_rank self.dp_rank = self.moriio_config.dp_rank - logger.info( - f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" - f",{self.is_producer= }" - ) - self.local_ip = self.moriio_config.local_ip self.local_kv_port = self.moriio_config.local_kv_port self.proxy_ip = self.moriio_config.proxy_ip @@ -1340,8 +1336,8 @@ class MoRIIOConnectorWorker: ), ) - logger.info( - "build IOEngine %s:%s", + logger.debug( + "build MORI IOEngine %s:%s", self.moriio_config.local_ip, self.moriio_config.local_kv_port, ) @@ -1390,8 +1386,6 @@ class MoRIIOConnectorWorker: self.moriio_config.handshake_port + get_port_offset(self.dp_rank, self.tp_rank) ) - logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") - logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}") self.engine_id: EngineId = engine_id self.world_size = get_tensor_model_parallel_world_size() @@ -1522,7 +1516,6 @@ class MoRIIOConnectorWorker: return self.built_write_session[remote_engine_id] def _ping(self, zmq_context): - http_request_address = f"http://{self.request_address}/v1/completions" role = "P" if self.is_producer else "D" @@ -1586,15 +1579,18 @@ class MoRIIOConnectorWorker: if self.metadata_socket not in socks: continue - + def close(self): - if hasattr(self, '_handshake_initiation_executor'): + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) - - if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t: + + if ( + hasattr(self, "_moriio_handshake_listener_t") + and self._moriio_handshake_listener_t + ): self._moriio_handshake_listener_t.join(timeout=0) - - if hasattr(self, 'zmq_context') and self.zmq_context: + + if hasattr(self, "zmq_context") and self.zmq_context: self.zmq_context.destroy(linger=0) self.zmq_context = None @@ -1621,12 +1617,9 @@ class MoRIIOConnectorWorker: # Listen for new requests for 
metadata. host = "*" - logger.info( - f"======> mori handshake starting listening on baseport: {base_port}" - ) path = make_zmq_path("tcp", host, base_port) - logger.info(f"======> mori handshake starting listening on path: {path}") + logger.debug(f" mori handshake starting listening on path: {path}") with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -1636,20 +1629,20 @@ class MoRIIOConnectorWorker: msg != MoRIIOConstants.GET_META_MSG and msg != MoRIIOConstants.POP_DONE_RECV ): - logger.error("Connection listener got unexpected message %s", msg) + logger.error("Connection listener got unexpected message") raise HandshakeError("handshake failed, unexpected msg type") elif msg == MoRIIOConstants.GET_META_MSG: sock.send_multipart( (identity, b"", encoded_data) ) # send local mori io engine meta data - logger.info("MoRIIO handshake listener sent metadata to %s") + logger.debug("MoRIIO handshake listener sent metadata") # now we send tensor meta data for each block buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) sock.send_multipart((identity, b"", buf)) elif msg == MoRIIOConstants.POP_DONE_RECV: _, req_id = sock.recv_multipart() - logger.info( - "MoRIIO handshake listener received done recv for req %s", + logger.debug( + "MoRIIO handshake listener received done recv for req", req_id.decode(), ) @@ -1671,10 +1664,7 @@ class MoRIIOConnectorWorker: port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) - logger.info( - "handshake Querying metadata on path: %s at remote rank %s", - path, - ) + logger.debug(f"handshake Querying metadata on path:{path}") # Send query for the request. 
with zmq_ctx(zmq.DEALER, path) as sock: @@ -1694,7 +1684,7 @@ class MoRIIOConnectorWorker: ) self.moriio_wrapper.remote_engine_ip = host - remote_agent_name=self.moriio_wrapper.register_remote_engine( + remote_agent_name = self.moriio_wrapper.register_remote_engine( metadata.agent_metadata ) @@ -1748,7 +1738,7 @@ class MoRIIOConnectorWorker: def request_ready(_f: Future[Any], entry=(req_id, meta)): logger.info("MoRIIO handshake done for request %s", req_id) self._ready_requests.put(entry) - self.load_ready_flag [remote_engine_id] = True + self.load_ready_flag[remote_engine_id] = True self.write_ready_flags[remote_engine_id] = True fut_list = [] @@ -1999,10 +1989,10 @@ class MoRIIOConnectorWorker: break else: break - - - def get_engine_name_with_dp(self, engine_name,dp_rank): + + def get_engine_name_with_dp(self, engine_name, dp_rank): return f"{engine_name}_dp{dp_rank}" + def start_load_kv(self, metadata: MoRIIOConnectorMetadata): """ Start loading by triggering non-blocking moriio_xfer. @@ -2022,7 +2012,7 @@ class MoRIIOConnectorWorker: str(meta.remote_host) + ":" + str(meta.remote_handshake_port) ) meta.remote_engine_id = remote_engine_id - dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0) + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -2038,20 +2028,21 @@ class MoRIIOConnectorWorker: self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. 
- while True: + while True: if ( self._ready_requests.empty() and remote_engine_id not in self.load_ready_flag and wait_handshake_readd_req ): continue - elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag: + elif ( + not self._ready_requests.empty() + and remote_engine_id in self.load_ready_flag + ): self._read_blocks_for_req(*self._ready_requests.get_nowait()) break else: break - - self._reqs_to_send.update(metadata.reqs_to_send) @@ -2089,11 +2080,6 @@ class MoRIIOConnectorWorker: return True return False - def _is_first_layer(self, layer_name): - if layer_name == list(self.kv_caches.keys())[0]: - return True - return False - def merge_contiguous_blocks( self, offsets_local: list[int], @@ -2230,9 +2216,8 @@ class MoRIIOConnectorWorker: ) -> None: if self.mode == MoRIIOMode.WRITE: return - - - dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0) + + dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) sessions = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] From 4034937733e9226d87cd05ac43dad35189be2ea2 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 10:14:34 +0000 Subject: [PATCH 14/62] remove port Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index cc9cc28a9980b..39d06b6e0bced 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -210,7 +210,6 @@ class MoRIIOConfig: local_ip: str local_kv_port: int proxy_ip: str - proxy_port: int local_ping_port: int proxy_ping_port: int http_port: int @@ -245,7 +244,6 @@ class MoRIIOConfig: local_ip=get_ip(), local_kv_port=base_kv_port + port_offset, 
proxy_ip=extra_config["proxy_ip"], - proxy_port=int(extra_config["proxy_port"]), local_ping_port=base_ping_port + port_offset, proxy_ping_port=int(extra_config["proxy_ping_port"]), http_port=int(extra_config["http_port"]), @@ -843,6 +841,8 @@ class MoRIIOConnector(KVConnectorBase_V1): super().__init__(vllm_config, role) assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None + self._set_port_defaults(vllm_config) + self.engine_id = ( str(get_ip()) + ":" @@ -869,6 +869,22 @@ class MoRIIOConnector(KVConnectorBase_V1): # Scheduler Side Methods ############################################################ + def _set_port_defaults(self, vllm_config: VllmConfig): + kv_transfer_config = vllm_config.kv_transfer_config + extra_config = kv_transfer_config.kv_connector_extra_config + + if "handshake_port" not in extra_config or not extra_config["handshake_port"]: + extra_config["handshake_port"] = "6301" + + if "notify_port" not in extra_config or not extra_config["notify_port"]: + extra_config["notify_port"] = "61005" + + if "local_ping_port" not in extra_config or not extra_config["local_ping_port"]: + extra_config["local_ping_port"] = "7583" + + if not kv_transfer_config.kv_port: + kv_transfer_config.kv_port = "7305" + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -1287,7 +1303,6 @@ class MoRIIOConnectorWorker: self.local_ip = self.moriio_config.local_ip self.local_kv_port = self.moriio_config.local_kv_port self.proxy_ip = self.moriio_config.proxy_ip - self.proxy_port = self.moriio_config.proxy_port self.local_ping_port = self.moriio_config.local_ping_port self.proxy_ping_port = self.moriio_config.proxy_ping_port self.http_port = self.moriio_config.http_port @@ -1352,7 +1367,7 @@ class MoRIIOConnectorWorker: f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}" ) logger.debug( - f"{self.local_ip 
= },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.proxy_port = },{self.local_ping_port = },{self.proxy_ping_port = }" + f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.local_ping_port = },{self.proxy_ping_port = }" ) # Agent. self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) From bba4c89ca4e7201a7aaa26e9d99e596b4a7be83f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 10:25:06 +0000 Subject: [PATCH 15/62] format Signed-off-by: inkcherry --- .../moriio_integration/toy_proxy_server.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py index 3091d98366c0a..27c450ee7b25c 100644 --- a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -9,15 +9,13 @@ import socket import threading import uuid +import aiohttp import msgpack import zmq from quart import Quart, make_response, request logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - -import aiohttp - prefill_instances = [] decode_instances = [] request_nums = 0 @@ -69,15 +67,21 @@ def _listen_for_register(hostname, port): data = msgpack.loads(msg) if data["type"] == "HELLO": pass - elif data["type"] == "register" and data["role"] == "P": - if data["request_address"] not in prefill_instances: - with _list_lock: - _append_whole_dict_unique(prefill_instances, data) + elif ( + data["type"] == "register" + and data["role"] == "P" + and data["request_address"] not in prefill_instances + ): + with _list_lock: + _append_whole_dict_unique(prefill_instances, data) - elif data["type"] == "register" and data["role"] == "D": - if data["request_address"] not in decode_instances: - with _list_lock: - 
_append_whole_dict_unique(decode_instances, data) + elif ( + data["type"] == "register" + and data["role"] == "D" + and data["request_address"] not in decode_instances + ): + with _list_lock: + _append_whole_dict_unique(decode_instances, data) def start_service_discovery(hostname, port): @@ -133,7 +137,7 @@ async def send_request_to_prefill( else: raise RuntimeError( - "send_request_to_prefill response.status != 200,response.statuus = ", + "send_request_to_prefill response.status != 200response.status = ", response.status, ) From 08cd2efbb69c5cfe03cafddb99ba8dae18520ba0 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 10:46:37 +0000 Subject: [PATCH 16/62] refine Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 39d06b6e0bced..22ba1655eabd6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -694,9 +694,9 @@ class MoRIIOWrapper: self._handle_completion_message(msg_str) handled = True except UnicodeDecodeError: - logger.warning(f"Received non-UTF8 message: {msg}") + logger.warning(f"Received non-UTF8 message: {msg_str}") if not handled: - raise MoRIIOError(f"Unhandled message format: {msg}") + raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): req_id = data["req_id"] @@ -784,7 +784,7 @@ class ReqMeta: remote_host: str remote_port: int remote_handshake_port: int - remote_notify_port: int + remote_notify_port: int | None remote_engine_id: str tp_size: int remote_dp_size: int @@ -1011,7 +1011,7 @@ class MoRIIOConnectorScheduler: self._reqs_need_send: dict[ReqId, float] = {} self.sock = None self.is_producer = vllm_config.kv_transfer_config.kv_role == 
"kv_producer" - self.paths = {} + self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens( self, @@ -1043,7 +1043,7 @@ class MoRIIOConnectorScheduler: return len(request.prompt_token_ids) - 1 - num_computed_tokens, False def send_notify_block( - self, req_id: str, block_notify_list: list[int] = None, host=None, port=None + self, req_id: str, block_notify_list: list[int] , host=None, port=None ): path = make_zmq_path("tcp", host, port) if path not in self.paths: @@ -1374,25 +1374,24 @@ class MoRIIOConnectorWorker: self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.notify_port = self.moriio_config.notify_port - self.local_kv_cache_metadata = [] - self.local_kv_cache_size = [] - self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() + self.local_kv_cache_metadata: list[bytes] = [] + self.local_kv_cache_size: list[int] = [] + self.layer_name_to_local_kv_cache_metadata: dict[str, list[bytes]] = {} - self.remote_kv_cache_metadata = [] - self.remote_kv_cache_size = [] + self.remote_kv_cache_metadata: list[bytes] = [] + self.remote_kv_cache_size: list[int] = [] self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( dict() ) self.slot_size_bytes = 0 - self.load_ready_flag = {} - self.write_ready_flags = {} + self.load_ready_flag: dict[str, bool] = {} + self.write_ready_flags: dict[str, bool] = {} self.kv_cache_shape = None self.block_shape = None self.kv_element_size = 0 - self.done_sending_reqs = [] - self.done_send_threads = [] + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -1452,7 +1451,6 @@ class MoRIIOConnectorWorker: self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) - self.debug_cache = [] backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, From a0d74ebf7f357b6ae281f921e049b70ef324af89 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 11:05:53 +0000 Subject: [PATCH 17/62] fix format error Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 22ba1655eabd6..d3343feb88dd0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -802,8 +802,8 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }" return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," - for req_id, req_meta in self.reqs_to_send.items(): - return_str += f"{req_id = },{req_meta = }" + for req_id, expiry in self.reqs_to_send.items(): + return_str += f"{req_id = },{expiry = }" return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," return return_str @@ -929,10 +929,11 @@ class MoRIIOConnector(KVConnectorBase_V1): return self.connector_worker.get_finished() def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + assert self.connector_worker is not None if self.mode == MoRIIOMode.WRITE: if get_role() == ROLE.CONSUMER: self.connector_worker.moriio_wrapper.async_wait_reqid() - assert 
self.connector_worker is not None + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) @@ -949,6 +950,9 @@ class MoRIIOConnector(KVConnectorBase_V1): # Only producer/prefill saves KV Cache if get_role() == ROLE.CONSUMER: return + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), ( + "Connector metadata not initialized yet" + ) self.connector_worker.save_kv_layer( self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs ) @@ -1043,7 +1047,7 @@ class MoRIIOConnectorScheduler: return len(request.prompt_token_ids) - 1 - num_computed_tokens, False def send_notify_block( - self, req_id: str, block_notify_list: list[int] , host=None, port=None + self, req_id: str, block_notify_list: list[int], host=None, port=None ): path = make_zmq_path("tcp", host, port) if path not in self.paths: @@ -1137,6 +1141,12 @@ class MoRIIOConnectorScheduler: for new_req in scheduler_output.scheduled_new_reqs: red_id = new_req.req_id local_block_ids = list(new_req.block_ids) + assert new_req.sampling_params is not None, ( + f"sampling_params is None for req {new_req.req_id}" + ) + assert hasattr(new_req.sampling_params, "extra_args"), ( + f"sampling_params missing extra_args for req {new_req.req_id}" + ) kv_transfer_params = new_req.sampling_params.extra_args[ "kv_transfer_params" ] @@ -1391,8 +1401,6 @@ class MoRIIOConnectorWorker: self.block_shape = None self.kv_element_size = 0 - - # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) From 9b90f5ddb231ff2cf29d3b1958661e5d451236bc Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 11:30:48 +0000 Subject: [PATCH 18/62] update Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index d3343feb88dd0..a3737700d20f3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -784,7 +784,7 @@ class ReqMeta: remote_host: str remote_port: int remote_handshake_port: int - remote_notify_port: int | None + remote_notify_port: int remote_engine_id: str tp_size: int remote_dp_size: int @@ -821,7 +821,7 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], remote_handshake_port=kv_transfer_params["remote_handshake_port"], - remote_notify_port=kv_transfer_params.get("remote_notify_port"), + remote_notify_port=kv_transfer_params["remote_notify_port"], tp_size=kv_transfer_params.get("tp_size", 1), remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), ) @@ -950,6 +950,10 @@ class MoRIIOConnector(KVConnectorBase_V1): # Only producer/prefill saves KV Cache if get_role() == ROLE.CONSUMER: return + assert self.connector_worker is not None, ( + "save_kv_layer called on scheduler role" + ) + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), ( "Connector metadata not initialized yet" ) @@ -1674,7 +1678,7 @@ class MoRIIOConnectorWorker: remote_tp_size: int, expected_engine_id: str, remote_dp_rank: int = 0, - ) -> dict[int, str]: + ) -> set[str]: """Do a MoRIIO handshake with a remote instance.""" start_time = time.perf_counter() @@ -2085,6 
+2089,7 @@ class MoRIIOConnectorWorker: def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer): # logger.debug(f"write block for req {req_id} to remote engine " # f"{meta.remote_engine_id}") + self.schedule_write_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -2189,6 +2194,7 @@ class MoRIIOConnectorWorker: Returns: Tuple of (local_offsets, remote_offsets, transfer_sizes) """ + assert self.kv_cache_shape is not None, "KV caches shape not initialized" is_mla = len(self.kv_cache_shape) == 3 stride = self.kv_caches[layer_name].stride() sz = self.kv_caches[layer_name].element_size() From 4c79f34e8aa73e73d7a66a0772dbcaf4420e62a0 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 11:36:31 +0000 Subject: [PATCH 19/62] fix mypy Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index a3737700d20f3..2e641e451319d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1678,7 +1678,7 @@ class MoRIIOConnectorWorker: remote_tp_size: int, expected_engine_id: str, remote_dp_rank: int = 0, - ) -> set[str]: + ) -> dict[int, str]: """Do a MoRIIO handshake with a remote instance.""" start_time = time.perf_counter() @@ -2263,7 +2263,7 @@ class MoRIIOConnectorWorker: self._recving_transfers[request_id].append(transfer_status) self._recving_transfers_callback_addr[request_id] = ( remote_host, - remote_notify_port + self.tp_rank, + str(remote_notify_port + self.tp_rank), ) From 3f7120368e2d1b7e730b9a1536cd8b3f5b1adae8 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 11:59:43 +0000 Subject: [PATCH 20/62] fix mypy and tp test pass Signed-off-by: inkcherry --- 
.../kv_transfer/kv_connector/v1/moriio_connector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 2e641e451319d..54d2802b464e0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1405,8 +1405,8 @@ class MoRIIOConnectorWorker: self.block_shape = None self.kv_element_size = 0 - # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. - self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + # Map of engine_id -> {agent_name0, agent_name1..}. + self._remote_agents: dict[EngineId, set[str]] = {} self.side_channel_port: int = ( self.moriio_config.handshake_port @@ -1448,7 +1448,7 @@ class MoRIIOConnectorWorker: thread_name_prefix="vllm-moriio-handshake-initiator", ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() - self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + self._handshake_futures: dict[EngineId, Future[set[str]]] = {} # Protects _handshake_futures and _remote_agents. 
self._handshake_lock = threading.RLock() @@ -1678,7 +1678,7 @@ class MoRIIOConnectorWorker: remote_tp_size: int, expected_engine_id: str, remote_dp_rank: int = 0, - ) -> dict[int, str]: + ) -> set[str]: """Do a MoRIIO handshake with a remote instance.""" start_time = time.perf_counter() @@ -1777,7 +1777,7 @@ class MoRIIOConnectorWorker: ) fut_list.append(future) - def done_callback(f: Future[dict[int, str]], eid=dp_engine_id): + def done_callback(f: Future[set[str]], eid=dp_engine_id): with self._handshake_lock: self._handshake_futures.pop(eid, None) try: From f75eecde0ac109e707557500ed4c0fa5c3c27ec5 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 12:33:31 +0000 Subject: [PATCH 21/62] fix all mypy Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 54d2802b464e0..61b83c03b2328 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1144,7 +1144,7 @@ class MoRIIOConnectorScheduler: if get_role() == ROLE.CONSUMER: for new_req in scheduler_output.scheduled_new_reqs: red_id = new_req.req_id - local_block_ids = list(new_req.block_ids) + local_block_ids = list(new_req.block_ids)[0] assert new_req.sampling_params is not None, ( f"sampling_params is None for req {new_req.req_id}" ) @@ -1174,9 +1174,7 @@ class MoRIIOConnectorScheduler: block_ids = new_block_ids[0] req, existing_blocks = self._reqs_need_pending_save[req_id] - updated_blocks = list(existing_blocks) + ( - [block_ids] if isinstance(block_ids, int) else block_ids - ) + updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) if ( len(self._reqs_need_pending_save[req_id][1]) From 
e0885e52d91b501dc2a5f6001c36a6628e4d7eed Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 03:55:32 +0000 Subject: [PATCH 22/62] break long line Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 61b83c03b2328..62a8ddc5d1a44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -73,7 +73,6 @@ class MoRIIOConstants: try: - import mori from mori.io import ( BackendType, EngineDesc, @@ -260,7 +259,8 @@ class MoRIIOConfig: class MoRIIOWriter: - """Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library + """Handles write operations for KV cache transfers. + Implements distributed KV cache transfer using the MoRIIO library for RDMA-based communication between prefill and decode instances.""" def __init__(self, worker: "MoRIIOConnectorWorker"): @@ -400,7 +400,9 @@ class MoRIIOWriter: return # Wait for CUDA event - # The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. + # The attention computation of the current layer cannot + # overlap with the kv transfer task, + # otherwise it will cause precision issues. # This event is used to synchronize the kv transfer and computation tasks. task.event.synchronize() @@ -497,8 +499,8 @@ class MoRIIOWriter: remote_port = task.remote_notify_port + get_port_offset( request_info.decode_dp_rank, self.worker.tp_rank ) - # TODO: - # Consider using RDMA immediate data in decode side to eliminate the need for this notification. + # Consider using RDMA immediate data in decode side + # to eliminate the need for this notification. 
# Consider including the first gen token from prefill in the notification # Send completion notification @@ -799,7 +801,11 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): def __repr__(self): return_str = "" for req_id, req_meta in self.reqs_to_recv.items(): - return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }" + return_str += ( + f"{req_id = },{req_meta.local_block_ids = }," + f"{req_meta.remote_host = },{req_meta.remote_port = }" + f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" + ) return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," for req_id, expiry in self.reqs_to_send.items(): @@ -862,7 +868,7 @@ class MoRIIOConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( - f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}" + "Initialized MoRIIO Connector,engine_id:{self.engine_id},role: {role.value}" ) ############################################################ @@ -1067,7 +1073,6 @@ class MoRIIOConnectorScheduler: "decode_rank": self.dp_rank, "type": "remote_blocks", } - # logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }") serialized_data = msgpack.dumps(data) self.paths[path].send(serialized_data) @@ -1139,7 +1144,8 @@ class MoRIIOConnectorScheduler: meta = MoRIIOConnectorMetadata() if self.mode == MoRIIOMode.WRITE: - # when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs + # when async_load_kv finished, + # new reqs will be added to scheduler_output.scheduled_new_reqs if get_role() == ROLE.CONSUMER: for new_req in scheduler_output.scheduled_new_reqs: @@ -1161,7 +1167,8 @@ class MoRIIOConnectorScheduler: ) if get_role() == ROLE.PRODUCER: # This is the logic for checking against chunked prefill. 
- # When the last chunk is identified, it places the request metadata into the saving queue. + # When the last chunk is identified, + # It places the request metadata into the saving queue. for i, req_id in enumerate( scheduler_output.scheduled_cached_reqs.req_ids @@ -1376,11 +1383,10 @@ class MoRIIOConnectorWorker: self._ping_thread.start() logger.info( - f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}" - ) - logger.debug( - f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.local_ping_port = },{self.proxy_ping_port = }" + f"Initializing MoRIIO Engine ,engine = {self.moriio_engine}," + f"role = {'producer' if self.is_producer else 'consumer'}" ) + # Agent. self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) self.moriio_wrapper.set_moriio_engine(self.moriio_engine) @@ -1568,7 +1574,8 @@ class MoRIIOConnectorWorker: except ConnectionRefusedError: logger.info( - f"Connection refused: {self.local_ip}:{self.local_ping_port} -> " + f"Connection refused: {self.local_ip}:" + f"{self.local_ping_port} -> " f"{self.proxy_ip}:{self.proxy_ping_port}" ) retry_count += 1 @@ -1584,7 +1591,8 @@ class MoRIIOConnectorWorker: finally: if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: logger.error( - f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop." + f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})" + "exceeded. Stopping ping loop." 
) break @@ -1718,12 +1726,14 @@ class MoRIIOConnectorWorker: ) if len(self.local_kv_cache_metadata) > 0: logger.warning( - f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + f"{len(self.local_kv_cache_metadata) = }," + "maybe you didnt clear this buffer correctly" ) self.local_kv_cache_metadata = [] if len(self.remote_kv_cache_metadata) > 0: logger.warning( - f" {len(self.remote_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + f" {len(self.remote_kv_cache_metadata) = }," + "maybe you didnt clear this buffer correctly" ) self.remote_kv_cache_metadata = [] @@ -1798,7 +1808,6 @@ class MoRIIOConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in moriio.""" - # kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize) _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() @@ -1836,8 +1845,6 @@ class MoRIIOConnectorWorker: self.block_shape = block_shape self.kv_element_size = kv_elem_size - # logger.info(f"Registering KV_Caches: {use_mla=}, {self.num_blocks=}, {block_shape=}, per_layer_kv_cache_shape={first_kv_cache.shape}") - self.dst_num_blocks[self.engine_id] = self.num_blocks self.kv_caches = kv_caches # layer name to kv cache kv_caches_base_addr = [] @@ -1849,7 +1856,6 @@ class MoRIIOConnectorWorker: if use_mla or self._use_flashinfer else cache_or_caches ) - # logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }") for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len From 795a305b1b9d94720d7568ff24387175d7fe99d2 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 04:49:34 +0000 Subject: [PATCH 23/62] fix format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 87 ++++++++++--------- 1 file changed, 47 insertions(+), 40 deletions(-) diff 
--git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 62a8ddc5d1a44..ce6bb4f824f6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -640,11 +640,11 @@ class MoRIIOWrapper: status.Wait() if not status.Succeeded(): logger.error( - f"Transfer failed: {status.Message()}, Code: {status.Code()}" + "Transfer failed: %s, Code: %s", status.Message(), status.Code() ) raise TransferError("MoRIIO transfer failed!") except Exception as e: - logger.error(f"Transfer {status} failed: {e}") + logger.error("Transfer %s failed: %s", status, e) raise def async_wait_reqid(self): @@ -656,7 +656,7 @@ class MoRIIOWrapper: def _async_wait(): host = "*" path = make_zmq_path("tcp", host, self.notify_port) - logger.info(f"Node starting to listen notify from path = {path}") + logger.info("Node starting to listen notify from path = %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: while True: @@ -664,7 +664,7 @@ class MoRIIOWrapper: identity, msg = sock.recv_multipart() self._handle_message(msg) except Exception as e: - logger.error(f"Error processing message: {e}") + logger.error("Error processing message: %s", e) raise HandshakeError(f"Error processing message: {e}") from e self.notify_thread = threading.Thread( @@ -696,7 +696,7 @@ class MoRIIOWrapper: self._handle_completion_message(msg_str) handled = True except UnicodeDecodeError: - logger.warning(f"Received non-UTF8 message: {msg_str}") + logger.warning("Received non-UTF8 message: %s", msg_str) if not handled: raise MoRIIOError(f"Unhandled message format: {msg_str}") @@ -740,11 +740,13 @@ class MoRIIOWrapper: try: for req_id in req_list: if not isinstance(req_id, str): - logger.warning(f"Invalid req_id type: {type(req_id)}, expected str") + logger.warning( + "Invalid req_id type: %s, expected str", type(req_id) + ) continue 
sock.send(req_id.encode("utf-8")) except Exception as e: - logger.error(f"Failed to send notification to {path}: {e}") + logger.error("Failed to send notification to %s: %s", path, e) self.paths.pop(path, None) raise @@ -936,9 +938,8 @@ class MoRIIOConnector(KVConnectorBase_V1): def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None - if self.mode == MoRIIOMode.WRITE: - if get_role() == ROLE.CONSUMER: - self.connector_worker.moriio_wrapper.async_wait_reqid() + if self.mode == MoRIIOMode.WRITE and get_role() == ROLE.CONSUMER: + self.connector_worker.moriio_wrapper.async_wait_reqid() assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) @@ -999,7 +1000,7 @@ class MoRIIOConnectorScheduler: "handshake_port" ] ) - logger.info(f"Initializing MoRIIO Scheduler {engine_id = }") + logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id) self.side_notify_port = ( self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] @@ -1383,8 +1384,9 @@ class MoRIIOConnectorWorker: self._ping_thread.start() logger.info( - f"Initializing MoRIIO Engine ,engine = {self.moriio_engine}," - f"role = {'producer' if self.is_producer else 'consumer'}" + "Initializing MoRIIO Engine, engine = %s, role = %s", + self.moriio_engine, + "producer" if self.is_producer else "consumer", ) # Agent. 
@@ -1550,7 +1552,7 @@ class MoRIIOConnectorWorker: retry_count = 0 index = 1 - + should_break = True with zmq_context.socket(zmq.DEALER) as sock: sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}") @@ -1574,30 +1576,33 @@ class MoRIIOConnectorWorker: except ConnectionRefusedError: logger.info( - f"Connection refused: {self.local_ip}:" - f"{self.local_ping_port} -> " - f"{self.proxy_ip}:{self.proxy_ping_port}" + "Connection refused: %s:%s -> %s:%s", + self.local_ip, + self.local_ping_port, + self.proxy_ip, + self.proxy_ping_port, ) retry_count += 1 except OSError as e: - logger.info(f"OS error when sending ping: {e}") + logger.info("OS error when sending ping: %s", e) retry_count += 1 except Exception as e: - logger.info(f"Unexpected error when sending ping: {e}") + logger.info("Unexpected error when sending ping: %s", e) retry_count += 1 finally: if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: logger.error( - f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})" - "exceeded. Stopping ping loop." + "Max retries (%s) exceeded. 
Stopping ping loop.", + MoRIIOConstants.MAX_PING_RETRIES, ) - break - + should_break = True time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 + if should_break: + break def handle_proxy_request(self): if self.is_producer: @@ -1606,7 +1611,7 @@ class MoRIIOConnectorWorker: ) while True: socks = dict(self.poller.poll()) - logger.debug(f"handle_proxy_request: {socks = }") + logger.debug("handle_proxy_request: socks = %s", socks) if self.metadata_socket not in socks: continue @@ -1650,7 +1655,7 @@ class MoRIIOConnectorWorker: host = "*" path = make_zmq_path("tcp", host, base_port) - logger.debug(f" mori handshake starting listening on path: {path}") + logger.debug("mori handshake starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -1695,11 +1700,11 @@ class MoRIIOConnectorWorker: port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) - logger.debug(f"handshake Querying metadata on path:{path}") + logger.debug("handshake Querying metadata on path: %s", path) # Send query for the request. 
with zmq_ctx(zmq.DEALER, path) as sock: - logger.info(f"prepare send msg INSTAZNCE: {path}") + logger.debug("prepare send msg INSTAZNCE: %s", path) sock.send(MoRIIOConstants.GET_META_MSG) received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": @@ -1719,21 +1724,26 @@ class MoRIIOConnectorWorker: metadata.agent_metadata ) - logger.info( - f"MoRIIO handshake: registered remote agent " - f"{remote_agent_name=} for engine ID " - f"{expected_engine_id=},f{path= }" + logger.debug( + "MoRIIO handshake: registered" + "remote agent %s for engine ID %s, path = %s", + remote_agent_name, + expected_engine_id, + path, ) + if len(self.local_kv_cache_metadata) > 0: logger.warning( - f"{len(self.local_kv_cache_metadata) = }," - "maybe you didnt clear this buffer correctly" + "len(self.local_kv_cache_metadata) = %s," + "maybe you didnt clear this buffer correctly", + len(self.local_kv_cache_metadata), ) self.local_kv_cache_metadata = [] if len(self.remote_kv_cache_metadata) > 0: logger.warning( - f" {len(self.remote_kv_cache_metadata) = }," - "maybe you didnt clear this buffer correctly" + "len(self.remote_kv_cache_metadata) = %s," + "maybe you didnt clear this buffer correctly", + len(self.remote_kv_cache_metadata), ) self.remote_kv_cache_metadata = [] @@ -1995,7 +2005,6 @@ class MoRIIOConnectorWorker: # Initiate handshake with remote engine to exchange metadata. 
with self._handshake_lock: if remote_engine_id not in self._remote_agents: - logger.info(f"*****background moriio {remote_engine_id = }") self._background_moriio_handshake( req_id, remote_engine_id, meta ) @@ -2106,9 +2115,7 @@ class MoRIIOConnectorWorker: ) def _is_last_layer(self, layer_name): - if layer_name == list(self.kv_caches.keys())[-1]: - return True - return False + return layer_name == list(self.kv_caches.keys())[-1] def merge_contiguous_blocks( self, @@ -2256,7 +2263,7 @@ class MoRIIOConnectorWorker: first_layer, local_block_ids, remote_block_ids ) - for layer_name in self.layer_name_to_local_kv_cache_metadata.keys(): + for layer_name in self.layer_name_to_local_kv_cache_metadata: sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) From 857d93cbfbe0189fd0e7f5746a6740e906da4255 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 05:05:07 +0000 Subject: [PATCH 24/62] fix all commit Signed-off-by: inkcherry --- .../moriio_integration/toy_proxy_server.py | 10 ++++++---- .../kv_transfer/kv_connector/v1/moriio_connector.py | 5 ++--- vllm/entrypoints/openai/serving_chat.py | 1 + vllm/entrypoints/openai/serving_engine.py | 3 ++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py index 27c450ee7b25c..67ef2b02c76a9 100644 --- a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -16,12 +16,11 @@ from quart import Quart, make_response, request logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -prefill_instances = [] -decode_instances = [] +prefill_instances: list[dict] = [] +decode_instances: list[dict] = [] request_nums = 0 app = Quart(__name__) -yield_chunk = set() IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") @@ -200,7 +199,10 @@ async def 
handle_request(): request_nums += 1 def extract_ip_port_fast(url): - return IP_PORT_PATTERN.search(url).groups() + match = IP_PORT_PATTERN.search(url) + if not match: + raise ValueError(f"Invalid URL format: {url}") + return match.groups() req_data = await request.get_json() request_id = str(uuid.uuid4()) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index ce6bb4f824f6d..df92a71702968 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -4,7 +4,6 @@ import contextlib import logging import math import os -import pickle import queue import threading import time @@ -1673,7 +1672,7 @@ class MoRIIOConnectorWorker: ) # send local mori io engine meta data logger.debug("MoRIIO handshake listener sent metadata") # now we send tensor meta data for each block - buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) + buf = msgpack.dumps(layer_name_to_local_kv_cache_metadata) sock.send_multipart((identity, b"", buf)) elif msg == MoRIIOConstants.POP_DONE_RECV: _, req_id = sock.recv_multipart() @@ -1752,7 +1751,7 @@ class MoRIIOConnectorWorker: assert 0, f"Unexpected frame! 
{received_frame = }" buf = received_frame[1] self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( - pickle.loads(buf) + msgpack.loads(buf) ) setup_agent_time = time.perf_counter() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 59e1c8d531793..5a95721967ff0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -328,6 +328,7 @@ class OpenAIServingChat(OpenAIServing): lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, + data_parallel_rank=data_parallel_rank, ) generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 4d9903c9c5745..102237dd65e05 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1172,7 +1172,7 @@ class OpenAIServing: lora_request: LoRARequest | None, trace_headers: Mapping[str, str] | None, priority: int, - data_parallel_rank: int, + data_parallel_rank: int | None, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" tokenization_kwargs: dict[str, Any] = {} @@ -1220,6 +1220,7 @@ class OpenAIServing: lora_request=lora_request, trace_headers=trace_headers, priority=priority, + data_parallel_rank=None, ) generator = self.engine_client.generate( From 96da87bfe050a074c40d584760668de50a176819 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 05:10:58 +0000 Subject: [PATCH 25/62] refine Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index df92a71702968..877811a64a8d3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1014,7 +1014,7 @@ class MoRIIOConnectorScheduler: self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} # For chunked prefill, we perform layer-wise access within the final chunk. - # TODO: Perform access at the end of each chunk. + # TODO: Perform transfer at end chunk. self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {} if self.is_producer: @@ -1461,9 +1461,6 @@ class MoRIIOConnectorWorker: self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - # List of block window sizes for each layer for local attention self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla self.built_session = False @@ -1775,8 +1772,6 @@ class MoRIIOConnectorWorker: tp_size = int(meta.tp_size) remote_dp_size = int(meta.remote_dp_size) - # TODO: handle failure state of future in the - # callback, we want to fail the request in this case. def request_ready(_f: Future[Any], entry=(req_id, meta)): logger.info("MoRIIO handshake done for request %s", req_id) self._ready_requests.put(entry) @@ -1998,8 +1993,7 @@ class MoRIIOConnectorWorker: meta.remote_engine_id = remote_engine_id - # TODO: mz get_remote_engine_id() for engine_id mapping. - dp0_remote_engine_id = f"{remote_engine_id}_dp0" + self.get_engine_name_with_dp(remote_engine_id, 0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. 
with self._handshake_lock: From 9d29f361fb760e47d10f53f919585a48f42c5dd7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 05:14:08 +0000 Subject: [PATCH 26/62] update Signed-off-by: inkcherry --- .../distributed/kv_transfer/kv_connector/v1/moriio_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 877811a64a8d3..068bc6f096297 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1993,7 +1993,7 @@ class MoRIIOConnectorWorker: meta.remote_engine_id = remote_engine_id - self.get_engine_name_with_dp(remote_engine_id, 0) + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: From 0a3ae0b0cc43173753fbd3d5f34722423c397ce0 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 05:19:34 +0000 Subject: [PATCH 27/62] update Signed-off-by: inkcherry --- .../distributed/kv_transfer/kv_connector/v1/moriio_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 068bc6f096297..78848f0ed2ed7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -2232,7 +2232,7 @@ class MoRIIOConnectorWorker: w += 1 merged_l, merged_r, merged_s = self.merge_contiguous_blocks( - offset_local, offset_remote, sizes, assume_sorted=True + offset_local, offset_remote, sizes, assume_sorted=False ) return merged_l, merged_r, merged_s From fd63437837f9a84a1ba91667a5efe63f2d381cc0 Mon Sep 17 00:00:00 2001 From: 
inkcherry Date: Fri, 21 Nov 2025 06:14:40 +0000 Subject: [PATCH 28/62] update Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 78848f0ed2ed7..9cefa0cd835d9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -869,7 +869,8 @@ class MoRIIOConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( - "Initialized MoRIIO Connector,engine_id:{self.engine_id},role: {role.value}" + "Initialized MoRIIO Connector,engine_id:%s,role: %s", + self.engine_id, role.value ) ############################################################ From 38d51f6dd8eabe549ebf23ec55fbecde4fb1576a Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 11:51:51 +0000 Subject: [PATCH 29/62] refine code Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 53 ++++++++++++------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 9cefa0cd835d9..931462f3c3dde 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -37,7 +37,7 @@ from vllm.distributed.parallel_state import ( ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import 
RequestStatus @@ -69,7 +69,8 @@ class MoRIIOConstants: PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 - + DEFAULT_HANDSHAKE_PORT = "6301" + DEFAULT_NOTIFY_PORT="61005" try: from mori.io import ( @@ -78,6 +79,8 @@ try: IOEngine, IOEngineConfig, MemoryDesc, + PollCqMode, + RdmaBackendConfig ) logger.info("MoRIIO is available") @@ -192,11 +195,12 @@ class TransferError(MoRIIOError): def get_moriio_mode() -> MoRIIOMode: read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower() - # logger.info(f"MoRIIO Connector Read Mode = {read_mode}") + logger.debug("MoRIIO Connector read_mode: %s", read_mode) if read_mode in ("true", "1", "yes", "on"): return MoRIIOMode.READ else: return MoRIIOMode.WRITE + def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: @@ -220,19 +224,21 @@ class MoRIIOConfig: @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": + # Port Configuration: - # local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it) + # local_ping_port -> Outgoing heartbeat to proxy # proxy_ping_port -> Remote proxy's heartbeat ingress port # http_port -> Instance's HTTP service endpoint - # local_kv_port -> KV service port for Mori engine - # notify_port -> For synchronizing stages between nodes + # local_kv_port -> service port for mori engine + # notify_port -> For synchronizing stages between prefill and decode + # handshake_port -> For initial handshake between mori engine + #TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports + kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() dp_rank = vllm_config.parallel_config.data_parallel_rank - base_kv_port = int(kv_transfer_config.kv_port) - base_ping_port = int(extra_config["local_ping_port"]) base_notify_port = int(extra_config["notify_port"]) dp_size = vllm_config.parallel_config.data_parallel_size 
tp_size = get_tensor_model_parallel_world_size() @@ -240,9 +246,9 @@ class MoRIIOConfig: return cls( local_ip=get_ip(), - local_kv_port=base_kv_port + port_offset, + local_kv_port=get_open_port(), proxy_ip=extra_config["proxy_ip"], - local_ping_port=base_ping_port + port_offset, + local_ping_port=get_open_port(), proxy_ping_port=int(extra_config["proxy_ping_port"]), http_port=int(extra_config["http_port"]), handshake_port=int(extra_config["handshake_port"]), @@ -545,7 +551,7 @@ class MoRIIOWrapper: self.notify_thread = None self.sock = None self.sessions: list[IOEngine.Session] = [] - self.paths = {} + self.paths: dict[str, zmq.Socket] = {} def set_moriio_engine(self, moriio_engine): assert moriio_engine is not None, ( @@ -554,7 +560,17 @@ class MoRIIOWrapper: self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): - self.moriio_engine.create_backend(backend_type) + qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1")) + post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1")) + num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1")) + poll_mode = PollCqMode.POLLING + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_worker_threads, + poll_mode, + ) + self.moriio_engine.create_backend(backend_type, rdma_cfg) def get_agent_metadata(self): engine_metadata = self.moriio_engine.get_engine_desc() @@ -700,6 +716,8 @@ class MoRIIOWrapper: raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): + + assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages" req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) decode_dp_rank = data.get("decode_rank", 0) @@ -882,16 +900,12 @@ class MoRIIOConnector(KVConnectorBase_V1): extra_config = kv_transfer_config.kv_connector_extra_config if "handshake_port" not in extra_config or not extra_config["handshake_port"]: - extra_config["handshake_port"] = "6301" + 
extra_config["handshake_port"] = MoRIIOConstants.DEFAULT_HANDSHAKE_PORT if "notify_port" not in extra_config or not extra_config["notify_port"]: - extra_config["notify_port"] = "61005" + extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT - if "local_ping_port" not in extra_config or not extra_config["local_ping_port"]: - extra_config["local_ping_port"] = "7583" - if not kv_transfer_config.kv_port: - kv_transfer_config.kv_port = "7305" def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -1180,7 +1194,7 @@ class MoRIIOConnectorScheduler: if new_block_ids is not None: block_ids = new_block_ids[0] - + #TODO : hybrid attn, etc req, existing_blocks = self._reqs_need_pending_save[req_id] updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) @@ -2261,6 +2275,7 @@ class MoRIIOConnectorWorker: sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) + #TODO : apply multi-session batch-read when moriio support it transfer_status = self.moriio_wrapper.read_remote_data( offs[2], offs[0], offs[1], sessions[sess_idx] ) From 72ccb5d77c69c0f515c470fe8667b92516872c2e Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 12:02:18 +0000 Subject: [PATCH 30/62] remove handle_proxy_request Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 931462f3c3dde..11552e5b460a3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1361,22 +1361,12 @@ class MoRIIOConnectorWorker: f":tp {self.tp_rank}:dp {self.dp_rank}" ) if not self.is_producer: - self.poller = zmq.Poller() - self.metadata_socket = 
self.zmq_context.socket(zmq.ROUTER) - self.metadata_socket.bind(f"tcp://{self.metadata_address}") - self.poller.register(self.metadata_socket, zmq.POLLIN) - self.moriio_engine = IOEngine( "consumer:" + engine_suffix, IOEngineConfig( self.moriio_config.local_ip, self.moriio_config.local_kv_port ), ) - - self._handle_request_thread = threading.Thread( - target=self.handle_proxy_request, daemon=True - ) - self._handle_request_thread.start() else: self.moriio_engine = IOEngine( "producer:" + engine_suffix, @@ -1384,7 +1374,6 @@ class MoRIIOConnectorWorker: self.moriio_config.local_ip, self.moriio_config.local_kv_port ), ) - logger.debug( "build MORI IOEngine %s:%s", self.moriio_config.local_ip, @@ -1609,23 +1598,23 @@ class MoRIIOConnectorWorker: "Max retries (%s) exceeded. Stopping ping loop.", MoRIIOConstants.MAX_PING_RETRIES, ) - should_break = True + should_break = True time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 if should_break: break - def handle_proxy_request(self): - if self.is_producer: - raise NotImplementedError( - "prefill instance doesn't need to send kv cache in pull mode" - ) - while True: - socks = dict(self.poller.poll()) - logger.debug("handle_proxy_request: socks = %s", socks) + # def handle_proxy_request(self): + # if self.is_producer: + # raise NotImplementedError( + # "prefill instance doesn't need to send kv cache in pull mode" + # ) + # while True: + # socks = dict(self.poller.poll()) + # logger.debug("handle_proxy_request: socks = %s", socks) - if self.metadata_socket not in socks: - continue + # if self.metadata_socket not in socks: + # continue def close(self): if hasattr(self, "_handshake_initiation_executor"): From 4776e2ddcf57af4a50f36122102c46e4a116c89c Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 12:13:24 +0000 Subject: [PATCH 31/62] more Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 11552e5b460a3..7ac303331146f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -37,7 +37,12 @@ from vllm.distributed.parallel_state import ( ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port +from vllm.utils.network_utils import ( + get_ip, + get_open_port, + make_zmq_path, + make_zmq_socket, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -70,7 +75,8 @@ class MoRIIOConstants: PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 DEFAULT_HANDSHAKE_PORT = "6301" - DEFAULT_NOTIFY_PORT="61005" + DEFAULT_NOTIFY_PORT = "61005" + try: from mori.io import ( @@ -80,7 +86,7 @@ try: IOEngineConfig, MemoryDesc, PollCqMode, - RdmaBackendConfig + RdmaBackendConfig, ) logger.info("MoRIIO is available") @@ -200,7 +206,6 @@ def get_moriio_mode() -> MoRIIOMode: return MoRIIOMode.READ else: return MoRIIOMode.WRITE - def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: @@ -224,7 +229,6 @@ class MoRIIOConfig: @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - # Port Configuration: # local_ping_port -> Outgoing heartbeat to proxy # proxy_ping_port -> Remote proxy's heartbeat ingress port @@ -233,8 +237,8 @@ class MoRIIOConfig: # notify_port -> For synchronizing stages between prefill and decode # handshake_port -> For initial handshake between mori engine - #TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports - + # TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports + kv_transfer_config = vllm_config.kv_transfer_config 
extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -716,8 +720,7 @@ class MoRIIOWrapper: raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): - - assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages" + assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) decode_dp_rank = data.get("decode_rank", 0) @@ -887,8 +890,9 @@ class MoRIIOConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( - "Initialized MoRIIO Connector,engine_id:%s,role: %s", - self.engine_id, role.value + "Initialized MoRIIO Connector,engine_id:%s,role: %s", + self.engine_id, + role.value, ) ############################################################ @@ -905,8 +909,6 @@ class MoRIIOConnector(KVConnectorBase_V1): if "notify_port" not in extra_config or not extra_config["notify_port"]: extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT - - def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -1194,7 +1196,7 @@ class MoRIIOConnectorScheduler: if new_block_ids is not None: block_ids = new_block_ids[0] - #TODO : hybrid attn, etc + # TODO : hybrid attn, etc req, existing_blocks = self._reqs_need_pending_save[req_id] updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) @@ -1356,26 +1358,17 @@ class MoRIIOConnectorWorker: self._ping_thread = None self._writer = MoRIIOWriter(self) - engine_suffix = ( - f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" - f":tp {self.tp_rank}:dp {self.dp_rank}" + role = "producer" if self.is_producer else "consumer" + engine_suffix = 
f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:tp{self.tp_rank}:dp{self.dp_rank}" + self.moriio_engine = IOEngine( + f"{role}:{engine_suffix}", + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), ) - if not self.is_producer: - self.moriio_engine = IOEngine( - "consumer:" + engine_suffix, - IOEngineConfig( - self.moriio_config.local_ip, self.moriio_config.local_kv_port - ), - ) - else: - self.moriio_engine = IOEngine( - "producer:" + engine_suffix, - IOEngineConfig( - self.moriio_config.local_ip, self.moriio_config.local_kv_port - ), - ) logger.debug( - "build MORI IOEngine %s:%s", + "build MORI IOEngine %s (ip=%s port=%s)", + f"{role}:{engine_suffix}", self.moriio_config.local_ip, self.moriio_config.local_kv_port, ) @@ -1604,18 +1597,6 @@ class MoRIIOConnectorWorker: if should_break: break - # def handle_proxy_request(self): - # if self.is_producer: - # raise NotImplementedError( - # "prefill instance doesn't need to send kv cache in pull mode" - # ) - # while True: - # socks = dict(self.poller.poll()) - # logger.debug("handle_proxy_request: socks = %s", socks) - - # if self.metadata_socket not in socks: - # continue - def close(self): if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -2264,7 +2245,7 @@ class MoRIIOConnectorWorker: sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) - #TODO : apply multi-session batch-read when moriio support it + # TODO : apply multi-session batch-read when moriio support it transfer_status = self.moriio_wrapper.read_remote_data( offs[2], offs[0], offs[1], sessions[sess_idx] ) From b29f405aa5184353acb2e331e201ba0cdf949db7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 12:23:48 +0000 Subject: [PATCH 32/62] update Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 7ac303331146f..461a0c96345c6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -77,6 +77,7 @@ class MoRIIOConstants: DEFAULT_HANDSHAKE_PORT = "6301" DEFAULT_NOTIFY_PORT = "61005" + VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 try: from mori.io import ( @@ -237,7 +238,8 @@ class MoRIIOConfig: # notify_port -> For synchronizing stages between prefill and decode # handshake_port -> For initial handshake between mori engine - # TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports + # TODO : merge notify_port and handshake_port to simplify port management + # supports non-contiguous ports kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config @@ -1289,7 +1291,7 @@ class MoRIIOConnectorScheduler: if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + time.perf_counter() + MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT ) # If we execute in P-D serial mode, no notification port is needed. 
@@ -1359,7 +1361,10 @@ class MoRIIOConnectorWorker: self._writer = MoRIIOWriter(self) role = "producer" if self.is_producer else "consumer" - engine_suffix = f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:tp{self.tp_rank}:dp{self.dp_rank}" + engine_suffix = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:" + f"tp{self.tp_rank}:dp{self.dp_rank}" + ) self.moriio_engine = IOEngine( f"{role}:{engine_suffix}", IOEngineConfig( @@ -1501,6 +1506,11 @@ class MoRIIOConnectorWorker: remote_ip: IP address of remote node """ + + # synchronization to prevent dirty reads between transfer and attention operations + # we can consider removing this synchronization after ibgda is enabled. + # when mori-io supports ibgda functionality + stream = torch.cuda.current_stream() event = torch.cuda.Event() event.record(stream) @@ -1931,8 +1941,6 @@ class MoRIIOConnectorWorker: else: done_recving = self._pop_done_transfers() else: - if self.mode == MoRIIOMode.WRITE: - self.moriio_wrapper.async_wait_reqid() done_sending, done_recving = ( set(), self.moriio_wrapper.pop_finished_write_req_ids(), From 1c10f47dc65fb32be00a9309bac1f961a64fe2ba Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 12:49:22 +0000 Subject: [PATCH 33/62] tp write single pass Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 461a0c96345c6..eb381c392d3ef 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -73,7 +73,7 @@ class MoRIIOConstants: COMPLETION_PREFIX = "cmpl" PING_INTERVAL = 5 - MAX_PING_RETRIES = 100000 + MAX_PING_RETRIES = 1000000 DEFAULT_HANDSHAKE_PORT = "6301" DEFAULT_NOTIFY_PORT = "61005" @@ -1932,7 +1932,7 @@ class 
MoRIIOConnectorWorker: to track which workers are done. """ - done_sending = set() + done_sending, done_recving = set() ,set() if self.is_producer: done_sending = self.moriio_wrapper.pop_finished_req_ids() From 16d2a7a34355a654af9519bf34849df1137826ff Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 15:03:06 +0000 Subject: [PATCH 34/62] update finished request collection Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index eb381c392d3ef..d737cf704b368 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -518,6 +518,8 @@ class MoRIIOWriter: self.worker.moriio_wrapper.send_notify( task.request_id, task.remote_ip, remote_port ) + # mark request as done, then we can free the blocks + self.worker.moriio_wrapper.done_req_ids.append(task.request_id) del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ task.request_id ] @@ -1936,15 +1938,13 @@ class MoRIIOConnectorWorker: if self.is_producer: done_sending = self.moriio_wrapper.pop_finished_req_ids() + + else: if self.mode == MoRIIOMode.WRITE: - done_recving = set() + done_recving = self.moriio_wrapper.pop_finished_write_req_ids() else: done_recving = self._pop_done_transfers() - else: - done_sending, done_recving = ( - set(), - self.moriio_wrapper.pop_finished_write_req_ids(), - ) + return done_sending, done_recving From 374cc25e0f3e4a28a9b5e850beb3e1b9063bc802 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 24 Nov 2025 04:54:48 +0000 Subject: [PATCH 35/62] format Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index d737cf704b368..1ae239ce330e7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -79,6 +79,7 @@ class MoRIIOConstants: VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 + try: from mori.io import ( BackendType, @@ -1293,7 +1294,8 @@ class MoRIIOConnectorScheduler: if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion self._reqs_need_send[request.request_id] = ( - time.perf_counter() + MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT + time.perf_counter() + + MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT ) # If we execute in P-D serial mode, no notification port is needed. @@ -1508,8 +1510,8 @@ class MoRIIOConnectorWorker: remote_ip: IP address of remote node """ - - # synchronization to prevent dirty reads between transfer and attention operations + # synchronization to prevent dirty reads between + # transfer and attention operations # we can consider removing this synchronization after ibgda is enabled. # when mori-io supports ibgda functionality @@ -1934,18 +1936,17 @@ class MoRIIOConnectorWorker: to track which workers are done. 
""" - done_sending, done_recving = set() ,set() + done_sending, done_recving = set(), set() if self.is_producer: done_sending = self.moriio_wrapper.pop_finished_req_ids() - + else: if self.mode == MoRIIOMode.WRITE: done_recving = self.moriio_wrapper.pop_finished_write_req_ids() else: done_recving = self._pop_done_transfers() - return done_sending, done_recving def _pop_done_transfers(self) -> set[str]: From 77321502e73ecdbf0ea2b68f518f3b3ec481e37b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 24 Nov 2025 06:53:45 +0000 Subject: [PATCH 36/62] update lock Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 1ae239ce330e7..143a15cf6db07 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -520,7 +520,8 @@ class MoRIIOWriter: task.request_id, task.remote_ip, remote_port ) # mark request as done, then we can free the blocks - self.worker.moriio_wrapper.done_req_ids.append(task.request_id) + with self.worker.moriio_wrapper.lock: + self.worker.moriio_wrapper.done_req_ids.append(task.request_id) del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ task.request_id ] @@ -1559,7 +1560,6 @@ class MoRIIOConnectorWorker: retry_count = 0 index = 1 - should_break = True with zmq_context.socket(zmq.DEALER) as sock: sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}") @@ -1598,18 +1598,17 @@ class MoRIIOConnectorWorker: except Exception as e: logger.info("Unexpected error when sending ping: %s", e) retry_count += 1 - - finally: if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: logger.error( "Max retries (%s) exceeded. 
Stopping ping loop.", MoRIIOConstants.MAX_PING_RETRIES, ) - should_break = True + raise RuntimeError(f"Ping failed after {retry_count} retries") + + finally: time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 - if should_break: - break + def close(self): if hasattr(self, "_handshake_initiation_executor"): @@ -1951,19 +1950,23 @@ class MoRIIOConnectorWorker: def _pop_done_transfers(self) -> set[str]: done_req_ids: set[str] = set() - for req_id, status_list in self._recving_transfers.items(): - if status_list[-1].Succeeded(): - done_req_ids.add(req_id) + with self.moriio_wrapper.lock: + to_remove = [] + for req_id, status_list in self._recving_transfers.items(): + if status_list[-1].Succeeded(): + done_req_ids.add(req_id) - self.moriio_wrapper.send_notify( - req_id, - self._recving_transfers_callback_addr[req_id][0], - self._recving_transfers_callback_addr[req_id][1], - ) + self.moriio_wrapper.send_notify( + req_id, + self._recving_transfers_callback_addr[req_id][0], + self._recving_transfers_callback_addr[req_id][1], + ) + to_remove.append(req_id) + for req_id in to_remove: del self._recving_transfers[req_id] del self._recving_transfers_callback_addr[req_id] - return done_req_ids + return done_req_ids def save_kv_layer( self, @@ -2258,12 +2261,13 @@ class MoRIIOConnectorWorker: transfer_status = self.moriio_wrapper.read_remote_data( offs[2], offs[0], offs[1], sessions[sess_idx] ) + with self.moriio_wrapper.lock: - self._recving_transfers[request_id].append(transfer_status) - self._recving_transfers_callback_addr[request_id] = ( - remote_host, - str(remote_notify_port + self.tp_rank), - ) + self._recving_transfers[request_id].append(transfer_status) + self._recving_transfers_callback_addr[request_id] = ( + remote_host, + str(remote_notify_port + self.tp_rank), + ) @contextlib.contextmanager From 63e6cff196abc624ce4567df083381f86363f260 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 24 Nov 2025 08:53:51 +0000 Subject: [PATCH 37/62] update proxy path 
Signed-off-by: inkcherry --- .../disaggregated_serving/moriio_toy_proxy_server.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/v1/kv_connector/moriio_integration/toy_proxy_server.py => examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py (100%) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py similarity index 100% rename from tests/v1/kv_connector/moriio_integration/toy_proxy_server.py rename to examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py From 7d3a93f1e7ac5565af923045ed3b374f9a2c6742 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 24 Nov 2025 09:02:04 +0000 Subject: [PATCH 38/62] format Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 143a15cf6db07..b24ee94508e84 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1603,12 +1603,13 @@ class MoRIIOConnectorWorker: "Max retries (%s) exceeded. 
Stopping ping loop.", MoRIIOConstants.MAX_PING_RETRIES, ) - raise RuntimeError(f"Ping failed after {retry_count} retries") + raise RuntimeError( + f"Ping failed after {retry_count} retries" + ) from e finally: time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 - def close(self): if hasattr(self, "_handshake_initiation_executor"): @@ -2262,7 +2263,6 @@ class MoRIIOConnectorWorker: offs[2], offs[0], offs[1], sessions[sess_idx] ) with self.moriio_wrapper.lock: - self._recving_transfers[request_id].append(transfer_status) self._recving_transfers_callback_addr[request_id] = ( remote_host, From b3b195a540a97dc99bc4cab830d4107bfea39f68 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 24 Nov 2025 09:18:49 +0000 Subject: [PATCH 39/62] update Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index b24ee94508e84..5135580924f14 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -20,7 +20,6 @@ import numpy as np import torch import zmq -from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig @@ -1014,9 +1013,8 @@ class MoRIIOConnectorScheduler: self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id - self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.mode = get_moriio_mode() - + self.host_ip = get_ip() self.handshake_port = ( self.vllm_config.kv_transfer_config.kv_connector_extra_config[ "handshake_port" @@ -1305,7 +1303,7 @@ class MoRIIOConnectorScheduler: do_remote_decode=False, remote_block_ids=computed_block_ids, 
remote_engine_id=self.engine_id, - remote_host=self.side_channel_host, + remote_host=self.host_ip, remote_port=self.handshake_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, ) From ad5678b0564db13be0a87c51952425af722d4f65 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 27 Nov 2025 07:23:48 +0000 Subject: [PATCH 40/62] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 69 ++++++++++++------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 5135580924f14..b1063f5487955 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -240,7 +240,9 @@ class MoRIIOConfig: # TODO : merge notify_port and handshake_port to simplify port management # supports non-contiguous ports - + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -745,12 +747,12 @@ class MoRIIOWrapper: else: self.done_write_cache_req_ids.append(msg) - def send_notify(self, req_ids, remote_ip=None, remote_port=None): + def send_notify(self, req_ids, remote_ip, remote_port): if not remote_ip or not remote_port: logger.warning("Missing remote_ip or remote_port for notification") return - path = make_zmq_path("tcp", remote_ip, str(remote_port)) + path = make_zmq_path("tcp", remote_ip, remote_port) if path not in self.paths: ctx = zmq.Context.instance() @@ -872,18 +874,18 @@ class MoRIIOConnector(KVConnectorBase_V1): kv_cache_config: Optional["KVCacheConfig"] = None, ): super().__init__(vllm_config, role) - assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config is not None, 
( + "kv_transfer_config must be set for MoRIIOConnector" + ) + + self.kv_transfer_config = vllm_config.kv_transfer_config # assert vllm_config.kv_transfer_config.engine_id is not None self._set_port_defaults(vllm_config) self.engine_id = ( str(get_ip()) + ":" - + str( - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "handshake_port" - ] - ) + + str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"]) ) self.mode = get_moriio_mode() if role == KVConnectorRole.SCHEDULER: @@ -905,6 +907,9 @@ class MoRIIOConnector(KVConnectorBase_V1): ############################################################ def _set_port_defaults(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config @@ -1011,23 +1016,26 @@ class MoRIIOConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + self.kv_transfer_config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id self.mode = get_moriio_mode() self.host_ip = get_ip() - self.handshake_port = ( - self.vllm_config.kv_transfer_config.kv_connector_extra_config[ - "handshake_port" - ] - ) + self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id) - self.side_notify_port = ( - self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] - ) + self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[ + "notify_port" + ] self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size self.dp_rank = 
self.vllm_config.parallel_config.data_parallel_rank - self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" + self.is_producer = self.kv_transfer_config.kv_role == "kv_producer" # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. @@ -1045,7 +1053,6 @@ class MoRIIOConnectorScheduler: # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} self.sock = None - self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens( @@ -1070,12 +1077,13 @@ class MoRIIOConnectorScheduler: if self.is_producer: return 0, False + token_ids = request.prompt_token_ids or [] if self.mode == MoRIIOMode.WRITE: # MoriiO in write mode, no remote prefill - return len(request.prompt_token_ids) - num_computed_tokens, True + return len(token_ids) - num_computed_tokens, True - return len(request.prompt_token_ids) - 1 - num_computed_tokens, False + return len(token_ids) - 1 - num_computed_tokens, False def send_notify_block( self, req_id: str, block_notify_list: list[int], host=None, port=None @@ -1105,6 +1113,8 @@ class MoRIIOConnectorScheduler: connector_worker: Optional["MoRIIOConnectorWorker"] = None, ): params = request.kv_transfer_params + if not params: + return if params.get("do_remote_decode"): local_block_ids = blocks.get_block_ids()[0] self._reqs_need_save[request.request_id] = (request, local_block_ids) @@ -1140,6 +1150,10 @@ class MoRIIOConnectorScheduler: ) else: + assert request.kv_transfer_params is not None, ( + "kv_transfer_params should not be None" + ) + remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) for tp_index in range(self.tp_size): @@ -1178,9 +1192,11 @@ class MoRIIOConnectorScheduler: assert hasattr(new_req.sampling_params, "extra_args"), ( f"sampling_params missing extra_args for req {new_req.req_id}" ) - 
kv_transfer_params = new_req.sampling_params.extra_args[ - "kv_transfer_params" - ] + kv_transfer_params = ( + new_req.sampling_params.extra_args.get("kv_transfer_params", {}) + if new_req.sampling_params.extra_args + else {} + ) meta.add_new_req( red_id, local_block_ids, @@ -1212,7 +1228,7 @@ class MoRIIOConnectorScheduler: meta.add_new_req( request_id=req_id, local_block_ids=self._reqs_need_pending_save[req_id][1], - kv_transfer_params=req.kv_transfer_params, + kv_transfer_params=req.kv_transfer_params or {}, write_mode=True, ) del self._reqs_need_pending_save[req_id] @@ -1328,6 +1344,9 @@ class MoRIIOConnectorWorker: # Config. self.vllm_config = vllm_config + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) self.kv_transfer_config = vllm_config.kv_transfer_config self.is_producer = self.kv_transfer_config.is_kv_producer From ea9b6871f6c0326e49d0498e5c13609049f7b309 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 27 Nov 2025 11:50:41 +0000 Subject: [PATCH 41/62] update Signed-off-by: inkcherry --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 9e8f9de7341b0..7cda86478664f 100644 --- a/.gitignore +++ b/.gitignore @@ -227,4 +227,3 @@ ep_kernels_workspace/ # Allow tracked library source folders under submodules (e.g., benchmarks/lib) !vllm/benchmarks/lib/ -examples/online_serving/disaggregated_serving_p2p_moriio_xpyd/ From 536668602c3413ca137eca5d8d48b526abb895b4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 06:25:53 +0000 Subject: [PATCH 42/62] fix format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index b1063f5487955..1299faff78336 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -544,23 +544,26 @@ class MoRIIOWrapper: dp_rank: Data parallel rank """ - def __init__(self, moriio_engine=None, tp_rank=0, dp_rank=0): + def __init__( + self, + moriio_engine: Optional["IOEngine"] = None, + tp_rank: int = 0, + dp_rank: int = 0, + ): self.tp_rank = tp_rank self.dp_rank = dp_rank self.moriio_engine = moriio_engine self.remote_memory_metadata = None self.local_memory_registered = False self.local_memory_metadata = None - self.transfer_status = [] - self.remote_engine_ip = None - self.notify_port = None - self.notify_sock = None + self.transfer_status: list[Any] = [] + self.remote_engine_ip: str | None = None + self.notify_port: int | None = None self.lock = threading.Lock() - self.done_req_ids = [] + self.done_req_ids: list[str] = [] self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} - self.done_write_cache_req_ids = [] - self.notify_thread = None - self.sock = None + self.done_write_cache_req_ids: list[str] = [] + self.notify_thread: threading.Thread | None = None self.sessions: list[IOEngine.Session] = [] self.paths: dict[str, zmq.Socket] = {} @@ -571,6 +574,7 @@ class MoRIIOWrapper: self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1")) post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1")) num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1")) @@ -584,20 +588,26 @@ class MoRIIOWrapper: self.moriio_engine.create_backend(backend_type, rdma_cfg) def get_agent_metadata(self): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" engine_metadata = self.moriio_engine.get_engine_desc() engine_metadata_packed = engine_metadata.pack() return engine_metadata_packed def 
register_remote_engine(self, remote_packed_engine_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) self.moriio_engine.register_remote_engine(consumer_engine_metadata) return consumer_engine_metadata.key def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" try: self.local_memory_metadata = self.moriio_engine.register_torch_tensor( tensor ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) local_memory_metadata_packed = self.local_memory_metadata.pack() except Exception as e: raise MoRIIOError(f"Failed to register local memory: {e}") from e @@ -608,6 +618,7 @@ class MoRIIOWrapper: return MemoryDesc.unpack(packed_memory_metadata) def build_session(self, local_memory_metadata, remote_memory_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" return self.moriio_engine.create_session( local_memory_metadata, remote_memory_metadata ) @@ -616,7 +627,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, session=None ): assert self.local_memory_registered, "You have not register local memory data!" - + assert self.moriio_engine is not None, "MoRIIO engine must be set first" transfer_status = session.batch_read( local_offset, remote_offset, @@ -630,6 +641,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, session=None ): assert self.local_memory_registered, "You have not register local memory data!" 
+ assert self.moriio_engine is not None, "MoRIIO engine must be set first" write_uid = self.moriio_engine.allocate_transfer_uid() transfer_status = session.batch_write( @@ -642,7 +654,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 ): assert self.local_memory_registered, "You have not register local memory data!" - + assert self.moriio_engine is not None, "MoRIIO engine must be set first" transfer_status = self.sessions[sess_idx].write( local_offset, remote_offset, @@ -1052,7 +1064,6 @@ class MoRIIOConnectorScheduler: set_role(ROLE.CONSUMER) # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} - self.sock = None self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens( From 1742f0cdfba1107eb8ccbcc9b08ae447072d7b70 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 07:35:10 +0000 Subject: [PATCH 43/62] clean up code Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 1299faff78336..823ed56b23806 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1351,8 +1351,6 @@ class MoRIIOConnectorWorker: logger.info("Initializing MoRIIO worker %s", engine_id) - logging.getLogger("aiter").disabled = True - # Config. self.vllm_config = vllm_config assert vllm_config.kv_transfer_config is not None, ( @@ -1507,12 +1505,9 @@ class MoRIIOConnectorWorker: self.block_size, use_mla=self.use_mla, ) + + #TODO: consider the integration of flashinfer or other backends. 
self.backend_name = backend.get_name() - attn_backend = AttentionBackendEnum[self.backend_name] - self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS - # attn_backend = backend_name_to_enum(self.backend_name) - # self._use_flashinfer = attn_backend == _Backend.FLASHINFER logger.debug("Detected attention backend %s", self.backend_name) def schedule_write_blocks( @@ -1854,13 +1849,8 @@ class MoRIIOConnectorWorker: self.slot_size_bytes = kv_elem_size * kv_latent_dim else: # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] block_size, n_kv_heads, head_dim = block_shape[-3:] # head size in bytes. 
@@ -1884,7 +1874,7 @@ class MoRIIOConnectorWorker: for cache_or_caches in kv_caches.values(): cache_list = ( [cache_or_caches] - if use_mla or self._use_flashinfer + if use_mla else cache_or_caches ) for cache in cache_list: From bbe6dad4013bd6c83c58a933ad7fafb33212d7c8 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 07:56:12 +0000 Subject: [PATCH 44/62] use envs Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 12 +++++----- vllm/envs.py | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 823ed56b23806..5b8370a56e488 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -19,7 +19,7 @@ import msgspec import numpy as np import torch import zmq - +from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig @@ -201,9 +201,9 @@ class TransferError(MoRIIOError): def get_moriio_mode() -> MoRIIOMode: - read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower() + read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE logger.debug("MoRIIO Connector read_mode: %s", read_mode) - if read_mode in ("true", "1", "yes", "on"): + if read_mode: return MoRIIOMode.READ else: return MoRIIOMode.WRITE @@ -575,9 +575,9 @@ class MoRIIOWrapper: def set_backend_type(self, backend_type): assert self.moriio_engine is not None, "MoRIIO engine must be set first" - qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1")) - post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1")) - num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1")) + qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE + 
num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS poll_mode = PollCqMode.POLLING rdma_cfg = RdmaBackendConfig( qp_per_transfer, diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d3981..f3b212dbe59cb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -193,6 +193,10 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False + VLLM_MORIIO_QP_PER_TRANSFER:int=1 + VLLM_MORIIO_POST_BATCH_SIZE:int=-1 + VLLM_MORIIO_NUM_WORKERS:int=1 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False @@ -1343,6 +1347,25 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), + # Controls the read mode for the Mori-IO connector + "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( + os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() + in ("true", "1") + ), + # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector + "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( + os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") + ), + + # Controls the post-processing batch size for the Mori-IO connector + "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( + os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") + ), + + # Controls the number of workers for Mori operations for the Mori-IO connector + "VLLM_MORIIO_NUM_WORKERS": lambda: int( + os.getenv("VLLM_MORIIO_NUM_WORKERS", "1") + ), # Controls whether or not to use cudnn prefill "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) From 6fbeee78d1d43f57c8a546a9ca8896f29b38d50b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 08:56:53 +0000 Subject: [PATCH 45/62] improve shutdown Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 38 
+++++++++++++++++-- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 5b8370a56e488..781a76baab2bf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -726,6 +726,7 @@ class MoRIIOWrapper: return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): + logger.debug("Failed to decode msgpack message, will try as string. Error: %s") pass try: @@ -802,6 +803,16 @@ class MoRIIOWrapper: done_write_cache = set(self.done_write_cache_req_ids) self.done_write_cache_req_ids = [] return done_write_cache + + def shutdown(self): + logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug(f"Closed ZMQ socket for path: {path}") + except Exception as e: + logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + self.paths.clear() class MoRIIOAgentMetadata( @@ -1011,6 +1022,12 @@ class MoRIIOConnector(KVConnectorBase_V1): def wait_for_save(self): pass + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + if self.connector_scheduler is not None: + self.connector_scheduler.shutdown() + def has_connector_metadata(self) -> bool: """Check whether the connector metadata is currently set. @@ -1351,6 +1368,8 @@ class MoRIIOConnectorWorker: logger.info("Initializing MoRIIO worker %s", engine_id) + logging.getLogger("aiter").disabled = True + # Config. 
self.vllm_config = vllm_config assert vllm_config.kv_transfer_config is not None, ( @@ -1634,7 +1653,18 @@ class MoRIIOConnectorWorker: time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 - def close(self): + def shutdown(self): + if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper: + self.moriio_wrapper.shutdown() + + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug(f"Closed ZMQ socket for path: {path}") + except Exception as e: + logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + self.paths.clear() + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -1649,7 +1679,7 @@ class MoRIIOConnectorWorker: self.zmq_context = None def __del__(self): - self.close() + self.shutdown() @staticmethod def _moriio_handshake_listener( @@ -1726,7 +1756,7 @@ class MoRIIOConnectorWorker: sock.send(MoRIIOConstants.GET_META_MSG) received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - assert 0, f"unexpected frame! {received_frame = }" + raise HandshakeError(f"Unexpected frame! {received_frame = }") metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) @@ -1767,7 +1797,7 @@ class MoRIIOConnectorWorker: received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - assert 0, f"Unexpected frame! {received_frame = }" + raise HandshakeError(f"unexpected frame! 
{received_frame = }") buf = received_frame[1] self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf) From bd6a3406525318357068cc2854f76dbf7dade3aa Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 09:07:49 +0000 Subject: [PATCH 46/62] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 50 +++++++++---------- vllm/envs.py | 17 +++---- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 781a76baab2bf..c4ce7965c8b7a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -3,7 +3,6 @@ import contextlib import logging import math -import os import queue import threading import time @@ -19,8 +18,8 @@ import msgspec import numpy as np import torch import zmq + from vllm import envs -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -201,7 +200,7 @@ class TransferError(MoRIIOError): def get_moriio_mode() -> MoRIIOMode: - read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE + read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE logger.debug("MoRIIO Connector read_mode: %s", read_mode) if read_mode: return MoRIIOMode.READ @@ -575,9 +574,9 @@ class MoRIIOWrapper: def set_backend_type(self, backend_type): assert self.moriio_engine is not None, "MoRIIO engine must be set first" - qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER - post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE - num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS + qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE + num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS poll_mode = 
PollCqMode.POLLING rdma_cfg = RdmaBackendConfig( qp_per_transfer, @@ -726,7 +725,7 @@ class MoRIIOWrapper: return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): - logger.debug("Failed to decode msgpack message, will try as string. Error: %s") + logger.debug("Failed to decode msgpack message, will try as string") pass try: @@ -803,15 +802,15 @@ class MoRIIOWrapper: done_write_cache = set(self.done_write_cache_req_ids) self.done_write_cache_req_ids = [] return done_write_cache - + def shutdown(self): logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") for path, sock in self.paths.items(): try: sock.close(linger=0) - logger.debug(f"Closed ZMQ socket for path: {path}") + logger.debug("Closed ZMQ socket for path: %s", path) except Exception as e: - logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) self.paths.clear() @@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler: return meta + def shutdown(self): + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug("Closed ZMQ socket for path: %s", path) + except Exception as e: + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) + self.paths.clear() + def request_finished( self, request: "Request", @@ -1524,8 +1532,8 @@ class MoRIIOConnectorWorker: self.block_size, use_mla=self.use_mla, ) - - #TODO: consider the integration of flashinfer or other backends. + + # TODO: consider the integration of flashinfer or other backends. 
self.backend_name = backend.get_name() logger.debug("Detected attention backend %s", self.backend_name) @@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker: index += 1 def shutdown(self): - if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper: + if hasattr(self, "moriio_wrapper") and self.moriio_wrapper: self.moriio_wrapper.shutdown() - - for path, sock in self.paths.items(): - try: - sock.close(linger=0) - logger.debug(f"Closed ZMQ socket for path: {path}") - except Exception as e: - logger.warning(f"Error closing ZMQ socket for path {path}: {e}") - self.paths.clear() - + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker: caches_data = [] for cache_or_caches in kv_caches.values(): - cache_list = ( - [cache_or_caches] - if use_mla - else cache_or_caches - ) + cache_list = [cache_or_caches] if use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len diff --git a/vllm/envs.py b/vllm/envs.py index f3b212dbe59cb..68bf5346fae7b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -193,10 +193,10 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 - VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False - VLLM_MORIIO_QP_PER_TRANSFER:int=1 - VLLM_MORIIO_POST_BATCH_SIZE:int=-1 - VLLM_MORIIO_NUM_WORKERS:int=1 + VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False + VLLM_MORIIO_QP_PER_TRANSFER: int = 1 + VLLM_MORIIO_POST_BATCH_SIZE: int = -1 + VLLM_MORIIO_NUM_WORKERS: int = 1 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False @@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Controls the read mode for the Mori-IO connector "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( - 
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() - in ("true", "1") + os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") ), # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") ), - # Controls the post-processing batch size for the Mori-IO connector "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") ), - # Controls the number of workers for Mori operations for the Mori-IO connector - "VLLM_MORIIO_NUM_WORKERS": lambda: int( - os.getenv("VLLM_MORIIO_NUM_WORKERS", "1") - ), + "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")), # Controls whether or not to use cudnn prefill "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) From 8a2c136c8ff704d8b8208bb212abd192c75a0a03 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 09:29:10 +0000 Subject: [PATCH 47/62] clean up Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio_connector.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index c4ce7965c8b7a..0debaca514d17 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -901,7 +901,6 @@ class MoRIIOConnector(KVConnectorBase_V1): ) self.kv_transfer_config = vllm_config.kv_transfer_config - # assert vllm_config.kv_transfer_config.engine_id is not None self._set_port_defaults(vllm_config) self.engine_id = ( @@ -2135,9 +2134,6 @@ class MoRIIOConnectorWorker: ) def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer): - # logger.debug(f"write block for req {req_id} to remote engine " - # f"{meta.remote_engine_id}") - 
self.schedule_write_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, From 532d8a7453415040f998c4e91fc87841ca1cff89 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 2 Dec 2025 06:08:03 +0000 Subject: [PATCH 48/62] Fix the issue of num_block inconsistency in non-MLA scenarios Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 0debaca514d17..7c566bc8c30ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -138,6 +138,19 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" +class MoRIIOAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + class RoleManager: """Manages role state across the connector.""" @@ -424,10 +437,10 @@ class MoRIIOWriter: ) # Get or create sessions - sessions = self.worker._get_built_session(task.dst_engine_id) + sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id) # Prepare transfer plan - plan = self._prepare_transfer_plan(task, request_info) + plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) # Execute transfer self._do_layer_write(plan, sessions) @@ -436,7 +449,7 @@ class MoRIIOWriter: self._finalize_if_complete(task, request_info) def _prepare_transfer_plan( - self, task: WriteTask, request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata ) -> LayerTransferPlan: """Prepare the transfer plan for a layer. 
@@ -450,7 +463,7 @@ class MoRIIOWriter: # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, task.local_block_ids, request_info.block_ids + task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta ) request_info.transfer_offset = offsets @@ -814,18 +827,6 @@ class MoRIIOWrapper: self.paths.clear() -class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True, -): - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: list[int] - num_blocks: int - block_len: int - attn_backend_name: str @dataclass @@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( dict() ) + self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {} self.slot_size_bytes = 0 self.load_ready_flag: dict[str, bool] = {} @@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker: if remote_engine_id not in self.built_write_session: cur_remote_engine_sessions = [] for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): - unpcaked_local_memory_meta = ( + unpacked_local_memory_meta = ( self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) ) - unpcaked_remote_memory_meta = ( + unpacked_remote_memory_meta = ( self.moriio_wrapper.get_unpack_memory_metadata( self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ ln @@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker: ) cur_remote_engine_sessions.append( self.moriio_wrapper.build_session( - unpcaked_local_memory_meta, unpcaked_remote_memory_meta + unpacked_local_memory_meta, unpacked_remote_memory_meta ) ) self.built_write_session[remote_engine_id] = cur_remote_engine_sessions - return self.built_write_session[remote_engine_id] + return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id] def _ping(self, 
zmq_context): http_request_address = f"http://{self.request_address}/v1/completions" @@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf) ) - + self.remote_moriio_metadata[expected_engine_id]=metadata setup_agent_time = time.perf_counter() logger.debug( "MoRIIO handshake: add agent took: %s", @@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker: layer_name: str, local_block_ids: list[int], remote_block_ids: list[int], + remote_moriio_meta: MoRIIOAgentMetadata, ) -> tuple[list[int], list[int], list[int]]: """Compute transfer offsets for block data. @@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker: layer_name: Name of the layer to transfer local_block_ids: IDs of local blocks remote_block_ids: IDs of remote blocks - + remote_moriio_meta: Metadata of the remote MoRIIO agent Returns: Tuple of (local_offsets, remote_offsets, transfer_sizes) """ @@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker: block_stride = stride[0] else: _, blknum, blksize, hn, hs = self.kv_cache_shape - ktov_stride = stride[0] + local_ktov_stride = stride[0] block_stride = stride[1] + remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks transfer_size_byte = blksize * hn * hs * sz per_block = 1 if is_mla else 2 @@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker: w += 1 if not is_mla: # V - offset_local[w] = sz * (1 * ktov_stride + lb * block_stride) - offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride) + # Handle num_block variations originating from PD (different kv strides) + # TODO: address block_sz differences in heterogeneous TP scenarios + # In MLA, we don't need to consider these two cases. 
+ offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride) + offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride) w += 1 merged_l, merged_r, merged_s = self.merge_contiguous_blocks( @@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker: return dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) - sessions = self._get_built_session(dp0_engine_id) + sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] offs = self._compute_block_transfer_offsets( - first_layer, local_block_ids, remote_block_ids + first_layer, local_block_ids, remote_block_ids, remote_moriio_meta ) for layer_name in self.layer_name_to_local_kv_cache_metadata: From e52ba22429a00e80daf5c97611d3353a6ccb6baf Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 2 Dec 2025 06:44:36 +0000 Subject: [PATCH 49/62] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 7c566bc8c30ba..e0e60683960ae 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -138,6 +138,7 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" + class MoRIIOAgentMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -437,7 +438,9 @@ class MoRIIOWriter: ) # Get or create sessions - sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id) + sessions, remote_moriio_meta = self.worker._get_built_session( + task.dst_engine_id + ) # Prepare transfer plan plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) @@ -449,7 +452,10 @@ class MoRIIOWriter: self._finalize_if_complete(task, request_info) def 
_prepare_transfer_plan( - self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata + self, + task: WriteTask, + request_info: RemoteAllocInfo, + remote_moriio_meta: MoRIIOAgentMetadata, ) -> LayerTransferPlan: """Prepare the transfer plan for a layer. @@ -463,7 +469,10 @@ class MoRIIOWriter: # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta + task.layer_name, + task.local_block_ids, + request_info.block_ids, + remote_moriio_meta, ) request_info.transfer_offset = offsets @@ -827,8 +836,6 @@ class MoRIIOWrapper: self.paths.clear() - - @dataclass class ReqMeta: """Metadata for a single request.""" @@ -1603,7 +1610,9 @@ class MoRIIOConnectorWorker: ) ) self.built_write_session[remote_engine_id] = cur_remote_engine_sessions - return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id] + return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[ + remote_engine_id + ] def _ping(self, zmq_context): http_request_address = f"http://{self.request_address}/v1/completions" @@ -1803,7 +1812,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf) ) - self.remote_moriio_metadata[expected_engine_id]=metadata + self.remote_moriio_metadata[expected_engine_id] = metadata setup_agent_time = time.perf_counter() logger.debug( "MoRIIO handshake: add agent took: %s", @@ -2251,7 +2260,7 @@ class MoRIIOConnectorWorker: _, blknum, blksize, hn, hs = self.kv_cache_shape local_ktov_stride = stride[0] block_stride = stride[1] - remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks + remote_ktov_stride = block_stride * remote_moriio_meta.num_blocks transfer_size_byte = blksize * hn * hs * sz per_block = 1 if is_mla else 2 From f98cde1997657c976bc14921484bb41cc0ad103a Mon 
Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 2 Dec 2025 06:48:38 +0000 Subject: [PATCH 50/62] clean up Signed-off-by: inkcherry --- vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index e0e60683960ae..3cad26aeb84d4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -295,7 +295,6 @@ class MoRIIOWriter: Args: worker: Reference to the parent worker """ - # self.worker = worker self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) self._write_task_q: Queue[WriteTask] = Queue() self._write_worker_started = False From bba01338ca866192318b23afbce5bd2ece8f13cf Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 17 Dec 2025 08:36:11 +0000 Subject: [PATCH 51/62] remove merge Signed-off-by: inkcherry --- vllm/entrypoints/openai/serving_chat.py | 1 - vllm/entrypoints/openai/serving_completion.py | 1 - vllm/entrypoints/openai/serving_engine.py | 3 --- 3 files changed, 5 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 37fc5e4a9a9d7..9a7051e0920af 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -335,7 +335,6 @@ class OpenAIServingChat(OpenAIServing): lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, - data_parallel_rank=data_parallel_rank, ) generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6cf000c3e79c1..9681aa8c71e6d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -226,7 +226,6 @@ class OpenAIServingCompletion(OpenAIServing): 
lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, - data_parallel_rank=data_parallel_rank, ) generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 624b936814b69..d9feee917ff4e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1207,7 +1207,6 @@ class OpenAIServing: lora_request: LoRARequest | None, trace_headers: Mapping[str, str] | None, priority: int, - data_parallel_rank: int | None, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" tokenization_kwargs: dict[str, Any] = {} @@ -1223,7 +1222,6 @@ class OpenAIServing: tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, - data_parallel_rank=data_parallel_rank, ) return engine_request, tokenization_kwargs @@ -1258,7 +1256,6 @@ class OpenAIServing: lora_request=lora_request, trace_headers=trace_headers, priority=priority, - data_parallel_rank=None, ) generator = self.engine_client.generate( From 0b0c33d59ee576f335a61eec4e1f837969f01008 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 22 Dec 2025 02:59:45 +0000 Subject: [PATCH 52/62] fix comments Signed-off-by: inkcherry --- .../moriio_toy_proxy_server.py | 16 ++++++++++++---- .../kv_connector/v1/moriio_connector.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py index 67ef2b02c76a9..98481787268c6 100644 --- a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py +++ b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py @@ -195,8 +195,9 @@ def example_round_robin_dp_loader(request_number, dp_size): @app.route("/v1/chat/completions", methods=["POST"]) async def handle_request(): try: - global 
request_nums - request_nums += 1 + with _list_lock: + global request_nums + request_nums += 1 def extract_ip_port_fast(url): match = IP_PORT_PATTERN.search(url) @@ -210,6 +211,10 @@ async def handle_request(): prefill_instance_endpoint = None decode_instance_endpoint = None + if not prefill_instances or not decode_instances: + return await make_response( + ("Service Unavailable: No prefill or decode instances are registered.", + 503)) pid = request_nums % len(prefill_instances) did = request_nums % len(decode_instances) prefill_instance_endpoint = prefill_instances[pid] @@ -291,8 +296,11 @@ async def handle_request(): response = await make_response(stream_generator) return response except Exception as e: - print(e) - pass + logger.exception("An error occurred while handling the request: %s", e) + return await make_response(( + f"Internal Server Error: {e!s}", + 500, + )) if __name__ == "__main__": diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 3cad26aeb84d4..1ce1f435cf544 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -71,7 +71,7 @@ class MoRIIOConstants: COMPLETION_PREFIX = "cmpl" PING_INTERVAL = 5 - MAX_PING_RETRIES = 1000000 + MAX_PING_RETRIES = 100 DEFAULT_HANDSHAKE_PORT = "6301" DEFAULT_NOTIFY_PORT = "61005" From 66e233268229afee03fe9dcfa7a14a2a925238ba Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 22 Dec 2025 04:55:29 +0000 Subject: [PATCH 53/62] split large file Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/factory.py | 2 +- .../kv_connector/v1/moriio/__init__.py | 0 .../kv_connector/v1/moriio/moriio_common.py | 321 +++++++ .../v1/{ => moriio}/moriio_connector.py | 878 +----------------- .../kv_connector/v1/moriio/moriio_engine.py | 607 ++++++++++++ 5 files changed, 954 insertions(+), 854 deletions(-) create mode 100644 
vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py rename vllm/distributed/kv_transfer/kv_connector/v1/{ => moriio}/moriio_connector.py (67%) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 954a5153ff67d..fd3d1e76d2450 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -181,7 +181,7 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MoRIIOConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector", + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", "MoRIIOConnector", ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py new file mode 100644 index 0000000000000..026e7faf57616 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + get_ip, + get_open_port, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from dataclasses import field +from enum import Enum + +logger = init_logger(__name__) + + +Transfer = tuple[int, float] +EngineId = str +ReqId = str + + +@dataclass +class WriteTask: + request_id: str + dst_engine_id: str + local_block_ids: list[int] + remote_block_ids_hint: list[int] | None + layer_name: str + event: torch.cuda.Event + remote_notify_port: int + remote_ip: str + enqueue_time: float = field(default_factory=time.perf_counter) + retried: int = 0 + + +@dataclass +class LayerTransferPlan: + """Plan for transferring a single layer.""" + + request_id: str + layer_name: str + sess_idx: int + transfer_local_offsets: list[int] + transfer_remote_offsets: list[int] + transfer_sizes: list[int] + use_batch: bool = True + + +@dataclass +class RemoteAllocInfo: + """Information about remote block allocation.""" + + block_ids: list[int] + writes_done: int = 0 + decode_dp_rank: int = 0 + transfer_offset: tuple[list[int], list[int], list[int]] | None = None + + +class ROLE(Enum): + PRODUCER = "producer" + CONSUMER = "consumer" + NOTINIT = "notinit" + + +class MoRIIOAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + + +class RoleManager: + """Manages role state across the connector.""" + + _instance: Optional["RoleManager"] = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._role: ROLE = ROLE.NOTINIT + + @classmethod + def get_instance(cls) -> "RoleManager": + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def set_role(self, role: ROLE) -> None: + """Set the 
current role.""" + with self._lock: + self._role = role + + def get_role(self) -> ROLE: + """Get the current role.""" + return self._role + + +def set_role(role: ROLE): + """Set the global role.""" + RoleManager.get_instance().set_role(role) + + +def get_role() -> ROLE: + """Get the global role.""" + return RoleManager.get_instance().get_role() + + +class MoRIIOMode(Enum): + READ = "read" + WRITE = "write" + + +class MoRIIOError(Exception): + """Base exception for MoRIIO operations.""" + + pass + + +class HandshakeError(MoRIIOError): + """Exception raised when handshake fails.""" + + pass + + +class TransferError(MoRIIOError): + """Exception raised when transfer fails.""" + + pass + + +def get_moriio_mode() -> MoRIIOMode: + read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE + logger.debug("MoRIIO Connector read_mode: %s", read_mode) + if read_mode: + return MoRIIOMode.READ + else: + return MoRIIOMode.WRITE + + +def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: + return (dp_rank) * tp_size + tp_rank + + +@dataclass +class MoRIIOConfig: + local_ip: str + local_kv_port: int + proxy_ip: str + local_ping_port: int + proxy_ping_port: int + http_port: int + handshake_port: int + notify_port: int + tp_rank: int + dp_rank: int + dp_size: int + tp_size: int + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": + # Port Configuration: + # local_ping_port -> Outgoing heartbeat to proxy + # proxy_ping_port -> Remote proxy's heartbeat ingress port + # http_port -> Instance's HTTP service endpoint + # local_kv_port -> service port for mori engine + # notify_port -> For synchronizing stages between prefill and decode + # handshake_port -> For initial handshake between mori engine + + # TODO : merge notify_port and handshake_port to simplify port management + # supports non-contiguous ports + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + kv_transfer_config = 
vllm_config.kv_transfer_config + extra_config = kv_transfer_config.kv_connector_extra_config + tp_rank = get_tensor_model_parallel_rank() + dp_rank = vllm_config.parallel_config.data_parallel_rank + base_notify_port = int(extra_config["notify_port"]) + dp_size = vllm_config.parallel_config.data_parallel_size + tp_size = get_tensor_model_parallel_world_size() + port_offset = get_port_offset(dp_rank, tp_rank) + + return cls( + local_ip=get_ip(), + local_kv_port=get_open_port(), + proxy_ip=extra_config["proxy_ip"], + local_ping_port=get_open_port(), + proxy_ping_port=int(extra_config["proxy_ping_port"]), + http_port=int(extra_config["http_port"]), + handshake_port=int(extra_config["handshake_port"]), + notify_port=base_notify_port + port_offset, + tp_rank=tp_rank, + dp_rank=dp_rank, + dp_size=dp_size, + tp_size=tp_size, + ) + + +class MoRIIOConstants: + """Constants for MoRIIO connector.""" + + # ZMQ message types + GET_META_MSG = b"get_meta_msg" + POP_DONE_RECV = b"pop_done_recv" + OVER = b"OVER" + COMPLETION_PREFIX = "cmpl" + + PING_INTERVAL = 5 + MAX_PING_RETRIES = 100 + DEFAULT_HANDSHAKE_PORT = "6301" + DEFAULT_NOTIFY_PORT = "61005" + + VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 + + +@dataclass +class ReqMeta: + """Metadata for a single request.""" + + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_handshake_port: int + remote_notify_port: int + remote_engine_id: str + tp_size: int + remote_dp_size: int + + +class MoRIIOConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_save: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: dict[ReqId, float] = {} + + def __repr__(self): + return_str = "" + for req_id, req_meta in self.reqs_to_recv.items(): + return_str += ( + f"{req_id = },{req_meta.local_block_ids = }," + f"{req_meta.remote_host = },{req_meta.remote_port = }" + f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" + ) + return_str = 
f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," + + for req_id, expiry in self.reqs_to_send.items(): + return_str += f"{req_id = },{expiry = }" + return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," + return return_str + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + write_mode=False, + ): + _req = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_handshake_port=kv_transfer_params["remote_handshake_port"], + remote_notify_port=kv_transfer_params["remote_notify_port"], + tp_size=kv_transfer_params.get("tp_size", 1), + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), + ) + if write_mode: + self.reqs_to_save[request_id] = _req + else: + self.reqs_to_recv[request_id] = _req + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: zmq.Context | None = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py similarity index 67% rename from vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py rename to vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 1ce1f435cf544..4b6bd906d5d44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -1,17 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib import logging import math import queue import threading import time from collections import defaultdict -from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional -from weakref import ref as weakref_ref import msgpack import msgspec @@ -19,7 +15,6 @@ import numpy as np import torch import zmq -from vllm import envs from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -27,8 +22,29 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + EngineId, + HandshakeError, + MoRIIOAgentMetadata, + MoRIIOConfig, + MoRIIOConnectorMetadata, + MoRIIOConstants, + MoRIIOMode, + ReqId, + ReqMeta, + WriteTask, + get_moriio_mode, + get_port_offset, + get_role, + set_role, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine import ( + MoRIIOWrapper, + MoRIIOWriter, +) from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, get_world_group, @@ -37,7 +53,6 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.utils.network_utils import ( get_ip, - get_open_port, make_zmq_path, make_zmq_socket, ) @@ -50,43 +65,13 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -from dataclasses import field -from enum import Enum -from queue import Empty, Queue - logger = init_logger(__name__) -Transfer = tuple[int, float] 
-EngineId = str -ReqId = str - - -class MoRIIOConstants: - """Constants for MoRIIO connector.""" - - # ZMQ message types - GET_META_MSG = b"get_meta_msg" - POP_DONE_RECV = b"pop_done_recv" - OVER = b"OVER" - COMPLETION_PREFIX = "cmpl" - - PING_INTERVAL = 5 - MAX_PING_RETRIES = 100 - DEFAULT_HANDSHAKE_PORT = "6301" - DEFAULT_NOTIFY_PORT = "61005" - - VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 - - try: from mori.io import ( BackendType, - EngineDesc, IOEngine, IOEngineConfig, - MemoryDesc, - PollCqMode, - RdmaBackendConfig, ) logger.info("MoRIIO is available") @@ -96,803 +81,8 @@ except ImportError: MoRIIO_enabled = False -@dataclass -class WriteTask: - request_id: str - dst_engine_id: str - local_block_ids: list[int] - remote_block_ids_hint: list[int] | None - layer_name: str - event: torch.cuda.Event - remote_notify_port: int - remote_ip: str - enqueue_time: float = field(default_factory=time.perf_counter) - retried: int = 0 - - -@dataclass -class LayerTransferPlan: - """Plan for transferring a single layer.""" - - request_id: str - layer_name: str - sess_idx: int - transfer_local_offsets: list[int] - transfer_remote_offsets: list[int] - transfer_sizes: list[int] - use_batch: bool = True - - -@dataclass -class RemoteAllocInfo: - """Information about remote block allocation.""" - - block_ids: list[int] - writes_done: int = 0 - decode_dp_rank: int = 0 - transfer_offset: tuple[list[int], list[int], list[int]] | None = None - - -class ROLE(Enum): - PRODUCER = "producer" - CONSUMER = "consumer" - NOTINIT = "notinit" - - -class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True, -): - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: list[int] - num_blocks: int - block_len: int - attn_backend_name: str - - -class RoleManager: - """Manages role state across the connector.""" - - _instance: Optional["RoleManager"] = None - _lock = threading.Lock() - - def __init__(self) -> None: 
- self._role: ROLE = ROLE.NOTINIT - - @classmethod - def get_instance(cls) -> "RoleManager": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def set_role(self, role: ROLE) -> None: - """Set the current role.""" - with self._lock: - self._role = role - - def get_role(self) -> ROLE: - """Get the current role.""" - return self._role - - -def set_role(role: ROLE): - """Set the global role.""" - RoleManager.get_instance().set_role(role) - - -def get_role() -> ROLE: - """Get the global role.""" - return RoleManager.get_instance().get_role() - - -class MoRIIOMode(Enum): - READ = "read" - WRITE = "write" - - -class MoRIIOError(Exception): - """Base exception for MoRIIO operations.""" - - pass - - -class HandshakeError(MoRIIOError): - """Exception raised when handshake fails.""" - - pass - - -class TransferError(MoRIIOError): - """Exception raised when transfer fails.""" - - pass - - -def get_moriio_mode() -> MoRIIOMode: - read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE - logger.debug("MoRIIO Connector read_mode: %s", read_mode) - if read_mode: - return MoRIIOMode.READ - else: - return MoRIIOMode.WRITE - - -def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: - return (dp_rank) * tp_size + tp_rank - - -@dataclass -class MoRIIOConfig: - local_ip: str - local_kv_port: int - proxy_ip: str - local_ping_port: int - proxy_ping_port: int - http_port: int - handshake_port: int - notify_port: int - tp_rank: int - dp_rank: int - dp_size: int - tp_size: int - - @classmethod - def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - # Port Configuration: - # local_ping_port -> Outgoing heartbeat to proxy - # proxy_ping_port -> Remote proxy's heartbeat ingress port - # http_port -> Instance's HTTP service endpoint - # local_kv_port -> service port for mori engine - # notify_port -> For synchronizing stages between prefill and decode - # handshake_port -> For initial 
handshake between mori engine - - # TODO : merge notify_port and handshake_port to simplify port management - # supports non-contiguous ports - assert vllm_config.kv_transfer_config is not None, ( - "kv_transfer_config must be set for MoRIIOConnector" - ) - kv_transfer_config = vllm_config.kv_transfer_config - extra_config = kv_transfer_config.kv_connector_extra_config - tp_rank = get_tensor_model_parallel_rank() - dp_rank = vllm_config.parallel_config.data_parallel_rank - base_notify_port = int(extra_config["notify_port"]) - dp_size = vllm_config.parallel_config.data_parallel_size - tp_size = get_tensor_model_parallel_world_size() - port_offset = get_port_offset(dp_rank, tp_rank) - - return cls( - local_ip=get_ip(), - local_kv_port=get_open_port(), - proxy_ip=extra_config["proxy_ip"], - local_ping_port=get_open_port(), - proxy_ping_port=int(extra_config["proxy_ping_port"]), - http_port=int(extra_config["http_port"]), - handshake_port=int(extra_config["handshake_port"]), - notify_port=base_notify_port + port_offset, - tp_rank=tp_rank, - dp_rank=dp_rank, - dp_size=dp_size, - tp_size=tp_size, - ) - - -"""Write task execution logic for MoRIIO connector.""" - - -class MoRIIOWriter: - """Handles write operations for KV cache transfers. - Implements distributed KV cache transfer using the MoRIIO library - for RDMA-based communication between prefill and decode instances.""" - - def __init__(self, worker: "MoRIIOConnectorWorker"): - """Initialize the writer. - - Args: - worker: Reference to the parent worker - """ - self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) - self._write_task_q: Queue[WriteTask] = Queue() - self._write_worker_started = False - self._write_worker_lock = threading.Lock() - self._deferred_tasks: list[WriteTask] = [] - - @property - def worker(self) -> "MoRIIOConnectorWorker": - """Get the worker instance. 
- - Returns: - The parent worker instance - - Raises: - RuntimeError: If worker has been garbage collected - """ - worker = self._worker_ref() - if worker is None: - raise RuntimeError("Parent worker has been garbage collected") - return worker - - def ensure_worker_started(self) -> None: - """Ensure the background write worker is running.""" - if self._write_worker_started: - return - self._write_worker_started = True - with self._write_worker_lock: - thread = threading.Thread( - target=self._write_worker_loop, daemon=True, name="moriio-write-worker" - ) - thread.start() - logger.info("Started MoRIIO write worker thread") - - def schedule_write(self, task: WriteTask) -> None: - """Schedule a write task. - - Args: - task: The write task to schedule - """ - self.ensure_worker_started() - self._write_task_q.put(task) - - def _write_worker_loop(self) -> None: - """Main loop for the write worker thread.""" - - while True: - # Process deferred tasks first - self._process_deferred_tasks() - - # Get new task - try: - task = self._write_task_q.get(timeout=0.01) - except Empty: - continue - - # Check if remote blocks are ready - if not self._is_remote_ready(task): - # task.retry_count += 1 - self._deferred_tasks.append(task) - # logger.debug( - # "Deferred task for request %s (retry %d)", - # task.request_id, task.retry_count - # ) - continue - - # Execute the task - - self._execute_write_task(task) - - def _process_deferred_tasks(self) -> None: - """Process tasks that were previously deferred.""" - if not self._deferred_tasks: - return - - still_deferred: list[WriteTask] = [] - for task in self._deferred_tasks: - if self._is_remote_ready(task): - self._execute_write_task(task) - else: - still_deferred.append(task) - - self._deferred_tasks = still_deferred - - def _is_remote_ready(self, task: WriteTask) -> bool: - """Check if remote blocks are allocated for this task. 
- - Args: - task: The write task - - Returns: - True if remote blocks are ready - """ - return ( - task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict - ) - - def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: - """Get remote allocation info for a request. - - Args: - request_id: The request ID - - Returns: - Remote allocation information - - Raises: - KeyError: If allocation info is missing - """ - try: - return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] - except KeyError as e: - raise KeyError( - f"Remote allocation info missing for request {request_id}" - ) from e - - def _execute_write_task(self, task: WriteTask) -> None: - """Execute a single write task. - - Args: - task: The write task to execute - - """ - # Get remote allocation info - request_info = self._get_remote_alloc_info(task.request_id) - - if request_info.block_ids is None: - logger.debug("Request %s remote block IDs not ready", task.request_id) - return - - # Wait for CUDA event - # The attention computation of the current layer cannot - # overlap with the kv transfer task, - # otherwise it will cause precision issues. - # This event is used to synchronize the kv transfer and computation tasks. 
- task.event.synchronize() - - # Update engine ID with DP rank - task.dst_engine_id = self.worker.get_engine_name_with_dp( - task.dst_engine_id, request_info.decode_dp_rank - ) - - # Get or create sessions - sessions, remote_moriio_meta = self.worker._get_built_session( - task.dst_engine_id - ) - - # Prepare transfer plan - plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) - - # Execute transfer - self._do_layer_write(plan, sessions) - - # Finalize if all layers complete - self._finalize_if_complete(task, request_info) - - def _prepare_transfer_plan( - self, - task: WriteTask, - request_info: RemoteAllocInfo, - remote_moriio_meta: MoRIIOAgentMetadata, - ) -> LayerTransferPlan: - """Prepare the transfer plan for a layer. - - Args: - task: The write task - request_info: Remote allocation information - - Returns: - The transfer plan - """ - # Compute offsets if not cached - if request_info.transfer_offset is None: - offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, - task.local_block_ids, - request_info.block_ids, - remote_moriio_meta, - ) - request_info.transfer_offset = offsets - - # Get session index - layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) - sess_idx = layer_names.index(task.layer_name) - - local_off, remote_off, sizes = request_info.transfer_offset - - return LayerTransferPlan( - request_id=task.request_id, - layer_name=task.layer_name, - sess_idx=sess_idx, - transfer_local_offsets=local_off, - transfer_remote_offsets=remote_off, - transfer_sizes=sizes, - use_batch=True, - ) - - def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: - """Perform the actual layer write. 
- - Args: - plan: The transfer plan - sessions: List of transfer sessions - """ - if plan.use_batch: - self.worker.moriio_wrapper.write_remote_data( - plan.transfer_sizes, - plan.transfer_local_offsets, - plan.transfer_remote_offsets, - sessions[plan.sess_idx], - ) - else: - for i in range(len(plan.transfer_local_offsets)): - self.worker.moriio_wrapper.write_remote_data_single( - plan.transfer_sizes[i], - plan.transfer_local_offsets[i], - plan.transfer_remote_offsets[i], - plan.sess_idx, - ) - - def _finalize_if_complete( - self, task: WriteTask, request_info: RemoteAllocInfo - ) -> None: - """Finalize transfer if all layers are complete. - - Args: - task: The write task - request_info: Remote allocation information - """ - request_info.writes_done += 1 - - if request_info.writes_done >= self.worker.num_layers: - # Wait for transfer to complete - self.worker.moriio_wrapper.waiting_for_transfer_complete() - - remote_port = task.remote_notify_port + get_port_offset( - request_info.decode_dp_rank, self.worker.tp_rank - ) - # Consider using RDMA immediate data in decode side - # to eliminate the need for this notification. - # Consider including the first gen token from prefill in the notification - - # Send completion notification - self.worker.moriio_wrapper.send_notify( - task.request_id, task.remote_ip, remote_port - ) - # mark request as done, then we can free the blocks - with self.worker.moriio_wrapper.lock: - self.worker.moriio_wrapper.done_req_ids.append(task.request_id) - del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ - task.request_id - ] - logger.debug( - "Completed transfer for request %s, notified port %d", - task.request_id, - remote_port, - ) - - -class MoRIIOWrapper: - """Wrapper for MoRIIO engine operations. - - Handles both producer and consumer roles for KV cache transfers. 
- - Args: - moriio_engine: MoRIIO engine instance - tp_rank: Tensor parallel rank - dp_rank: Data parallel rank - """ - - def __init__( - self, - moriio_engine: Optional["IOEngine"] = None, - tp_rank: int = 0, - dp_rank: int = 0, - ): - self.tp_rank = tp_rank - self.dp_rank = dp_rank - self.moriio_engine = moriio_engine - self.remote_memory_metadata = None - self.local_memory_registered = False - self.local_memory_metadata = None - self.transfer_status: list[Any] = [] - self.remote_engine_ip: str | None = None - self.notify_port: int | None = None - self.lock = threading.Lock() - self.done_req_ids: list[str] = [] - self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} - self.done_write_cache_req_ids: list[str] = [] - self.notify_thread: threading.Thread | None = None - self.sessions: list[IOEngine.Session] = [] - self.paths: dict[str, zmq.Socket] = {} - - def set_moriio_engine(self, moriio_engine): - assert moriio_engine is not None, ( - "You Cannot pass None engine to MoRIIOWrapper!" 
- ) - self.moriio_engine = moriio_engine - - def set_backend_type(self, backend_type): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER - post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE - num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS - poll_mode = PollCqMode.POLLING - rdma_cfg = RdmaBackendConfig( - qp_per_transfer, - post_batch_size, - num_worker_threads, - poll_mode, - ) - self.moriio_engine.create_backend(backend_type, rdma_cfg) - - def get_agent_metadata(self): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - engine_metadata = self.moriio_engine.get_engine_desc() - engine_metadata_packed = engine_metadata.pack() - return engine_metadata_packed - - def register_remote_engine(self, remote_packed_engine_metadata): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) - self.moriio_engine.register_remote_engine(consumer_engine_metadata) - return consumer_engine_metadata.key - - def register_local_tensor(self, tensor: torch.Tensor): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - try: - self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor - ) - assert self.local_memory_metadata is not None, ( - "register_torch_tensor returned None" - ) - local_memory_metadata_packed = self.local_memory_metadata.pack() - except Exception as e: - raise MoRIIOError(f"Failed to register local memory: {e}") from e - self.local_memory_registered = True - return local_memory_metadata_packed - - def get_unpack_memory_metadata(self, packed_memory_metadata): - return MemoryDesc.unpack(packed_memory_metadata) - - def build_session(self, local_memory_metadata, remote_memory_metadata): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - return self.moriio_engine.create_session( - local_memory_metadata, 
remote_memory_metadata - ) - - def read_remote_data( - self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): - assert self.local_memory_registered, "You have not register local memory data!" - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - transfer_status = session.batch_read( - local_offset, - remote_offset, - transfer_size_byte, - self.moriio_engine.allocate_transfer_uid(), - ) - - return transfer_status - - def write_remote_data( - self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): - assert self.local_memory_registered, "You have not register local memory data!" - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - write_uid = self.moriio_engine.allocate_transfer_uid() - - transfer_status = session.batch_write( - local_offset, remote_offset, transfer_size_byte, write_uid - ) - with self.lock: - self.transfer_status.append(transfer_status) - - def write_remote_data_single( - self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 - ): - assert self.local_memory_registered, "You have not register local memory data!" 
- assert self.moriio_engine is not None, "MoRIIO engine must be set first" - transfer_status = self.sessions[sess_idx].write( - local_offset, - remote_offset, - transfer_size_byte, - self.moriio_engine.allocate_transfer_uid(), - ) - with self.lock: - self.transfer_status.append(transfer_status) - - def waiting_for_transfer_complete(self): - if not self.transfer_status: - return - - transfers_to_wait = [] - with self.lock: - transfers_to_wait = self.transfer_status[:] - self.transfer_status.clear() - - for status in transfers_to_wait: - try: - status.Wait() - if not status.Succeeded(): - logger.error( - "Transfer failed: %s, Code: %s", status.Message(), status.Code() - ) - raise TransferError("MoRIIO transfer failed!") - except Exception as e: - logger.error("Transfer %s failed: %s", status, e) - raise - - def async_wait_reqid(self): - assert self.notify_port is not None, "Notify port cannot be None" - - if self.notify_thread is not None: - return - - def _async_wait(): - host = "*" - path = make_zmq_path("tcp", host, self.notify_port) - logger.info("Node starting to listen notify from path = %s", path) - - with zmq_ctx(zmq.ROUTER, path) as sock: - while True: - try: - identity, msg = sock.recv_multipart() - self._handle_message(msg) - except Exception as e: - logger.error("Error processing message: %s", e) - raise HandshakeError(f"Error processing message: {e}") from e - - self.notify_thread = threading.Thread( - target=_async_wait, daemon=True, name="moriio-notify-listener" - ) - self.notify_thread.start() - - def _handle_message(self, msg: bytes): - """Handles incoming messages from remote nodes.""" - # Handles incoming remote messages: - # Prefill Role: - # [write] mode: receives block information (allocation) - # [read] mode: receives block release messages from decode side - # Decode Role: - # [write] mode: receives KV cache write completion notifications - handled = False - try: - data = msgpack.loads(msg) - if isinstance(data, dict) and "req_id" in data: - 
self._handle_structured_message(data) - - return - except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): - logger.debug("Failed to decode msgpack message, will try as string") - pass - - try: - msg_str = msg.decode("UTF-8") - if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX): - self._handle_completion_message(msg_str) - handled = True - except UnicodeDecodeError: - logger.warning("Received non-UTF8 message: %s", msg_str) - if not handled: - raise MoRIIOError(f"Unhandled message format: {msg_str}") - - def _handle_structured_message(self, data: dict): - assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" - req_id = data["req_id"] - block_notify_list = data.get("block_notify_list", []) - decode_dp_rank = data.get("decode_rank", 0) - assert len(block_notify_list) > 0, ( - "block_notify_list cannot be empty in remote allocate message" - ) - - with self.lock: - self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( - block_ids=block_notify_list, decode_dp_rank=decode_dp_rank - ) - - def _handle_completion_message(self, msg: str): - with self.lock: - if get_role() == ROLE.PRODUCER: - self.done_req_ids.append(msg) - else: - self.done_write_cache_req_ids.append(msg) - - def send_notify(self, req_ids, remote_ip, remote_port): - if not remote_ip or not remote_port: - logger.warning("Missing remote_ip or remote_port for notification") - return - - path = make_zmq_path("tcp", remote_ip, remote_port) - - if path not in self.paths: - ctx = zmq.Context.instance() - sock = make_zmq_socket( - ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False - ) - self.paths[path] = sock - - req_list = req_ids if isinstance(req_ids, list) else [req_ids] - - sock = self.paths[path] - try: - for req_id in req_list: - if not isinstance(req_id, str): - logger.warning( - "Invalid req_id type: %s, expected str", type(req_id) - ) - continue - sock.send(req_id.encode("utf-8")) - except Exception as e: - logger.error("Failed to send notification 
to %s: %s", path, e) - self.paths.pop(path, None) - raise - - def pop_finished_req_ids(self): - # producer invocation: get the set of completed requests at the decode - with self.lock: - done_send = set(self.done_req_ids) - self.done_req_ids = [] - return done_send - - def pop_finished_write_req_ids(self): - # Call the consumer in write mode to get the collection after write completion - with self.lock: - done_write_cache = set(self.done_write_cache_req_ids) - self.done_write_cache_req_ids = [] - return done_write_cache - - def shutdown(self): - logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") - for path, sock in self.paths.items(): - try: - sock.close(linger=0) - logger.debug("Closed ZMQ socket for path: %s", path) - except Exception as e: - logger.warning("Error closing ZMQ socket for path %s: %s", path, e) - self.paths.clear() - - -@dataclass -class ReqMeta: - """Metadata for a single request.""" - - local_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_port: int - remote_handshake_port: int - remote_notify_port: int - remote_engine_id: str - tp_size: int - remote_dp_size: int - - -class MoRIIOConnectorMetadata(KVConnectorMetadata): - def __init__(self): - self.reqs_to_recv: dict[ReqId, ReqMeta] = {} - self.reqs_to_save: dict[ReqId, ReqMeta] = {} - self.reqs_to_send: dict[ReqId, float] = {} - - def __repr__(self): - return_str = "" - for req_id, req_meta in self.reqs_to_recv.items(): - return_str += ( - f"{req_id = },{req_meta.local_block_ids = }," - f"{req_meta.remote_host = },{req_meta.remote_port = }" - f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" - ) - return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," - - for req_id, expiry in self.reqs_to_send.items(): - return_str += f"{req_id = },{expiry = }" - return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," - return return_str - - def add_new_req( - self, - request_id: ReqId, - local_block_ids: list[int], - kv_transfer_params: 
dict[str, Any], - write_mode=False, - ): - _req = ReqMeta( - local_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - remote_handshake_port=kv_transfer_params["remote_handshake_port"], - remote_notify_port=kv_transfer_params["remote_notify_port"], - tp_size=kv_transfer_params.get("tp_size", 1), - remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), - ) - if write_mode: - self.reqs_to_save[request_id] = _req - else: - self.reqs_to_recv[request_id] = _req +def is_moriio_available() -> bool: + return MoRIIO_enabled class MoRIIOConnector(KVConnectorBase_V1): @@ -1371,7 +561,7 @@ class MoRIIOConnectorWorker: """Implementation of Worker side methods""" def __init__(self, vllm_config: VllmConfig, engine_id: str): - if not MoRIIO_enabled: + if not is_moriio_available(): raise RuntimeError( "MoRIIO is not available. Please ensure the 'mori' package " "is installed and properly configured." 
@@ -2323,21 +1513,3 @@ class MoRIIOConnectorWorker: remote_host, str(remote_notify_port + self.tp_rank), ) - - -@contextlib.contextmanager -def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: - """Context manager for a ZMQ socket""" - - if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): - raise ValueError(f"Unexpected socket type: {socket_type}") - - ctx: zmq.Context | None = None - try: - ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket( - ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER - ) - finally: - if ctx is not None: - ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py new file mode 100644 index 0000000000000..4357c0335ef99 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -0,0 +1,607 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import TYPE_CHECKING, Any, Optional +from weakref import ref as weakref_ref + +import msgpack +import torch +import zmq + +from vllm import envs +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + make_zmq_path, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from queue import Empty, Queue + +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + HandshakeError, + LayerTransferPlan, + MoRIIOAgentMetadata, + MoRIIOConstants, + MoRIIOError, + RemoteAllocInfo, + TransferError, + WriteTask, + get_port_offset, + get_role, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + MoRIIOConnectorWorker, +) + +logger = init_logger(__name__) +try: + from mori.io import ( + EngineDesc, + IOEngine, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + ) + + logger.info("MoRIIO is available") +except 
ImportError: + logger.error("MoRIIO is not available") + + +"""Write task execution logic for MoRIIO connector.""" + + +class MoRIIOWriter: + """Handles write operations for KV cache transfers. + Implements distributed KV cache transfer using the MoRIIO library + for RDMA-based communication between prefill and decode instances.""" + + def __init__(self, worker: "MoRIIOConnectorWorker"): + """Initialize the writer. + + Args: + worker: Reference to the parent worker + """ + self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) + self._write_task_q: Queue[WriteTask] = Queue() + self._write_worker_started = False + self._write_worker_lock = threading.Lock() + self._deferred_tasks: list[WriteTask] = [] + + @property + def worker(self) -> "MoRIIOConnectorWorker": + """Get the worker instance. + + Returns: + The parent worker instance + + Raises: + RuntimeError: If worker has been garbage collected + """ + worker = self._worker_ref() + if worker is None: + raise RuntimeError("Parent worker has been garbage collected") + return worker + + def ensure_worker_started(self) -> None: + """Ensure the background write worker is running.""" + if self._write_worker_started: + return + self._write_worker_started = True + with self._write_worker_lock: + thread = threading.Thread( + target=self._write_worker_loop, daemon=True, name="moriio-write-worker" + ) + thread.start() + logger.info("Started MoRIIO write worker thread") + + def schedule_write(self, task: WriteTask) -> None: + """Schedule a write task. 
+ + Args: + task: The write task to schedule + """ + self.ensure_worker_started() + self._write_task_q.put(task) + + def _write_worker_loop(self) -> None: + """Main loop for the write worker thread.""" + + while True: + # Process deferred tasks first + self._process_deferred_tasks() + + # Get new task + try: + task = self._write_task_q.get(timeout=0.01) + except Empty: + continue + + # Check if remote blocks are ready + if not self._is_remote_ready(task): + # task.retry_count += 1 + self._deferred_tasks.append(task) + # logger.debug( + # "Deferred task for request %s (retry %d)", + # task.request_id, task.retry_count + # ) + continue + + # Execute the task + + self._execute_write_task(task) + + def _process_deferred_tasks(self) -> None: + """Process tasks that were previously deferred.""" + if not self._deferred_tasks: + return + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + if self._is_remote_ready(task): + self._execute_write_task(task) + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred + + def _is_remote_ready(self, task: WriteTask) -> bool: + """Check if remote blocks are allocated for this task. + + Args: + task: The write task + + Returns: + True if remote blocks are ready + """ + return ( + task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict + ) + + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: + """Get remote allocation info for a request. + + Args: + request_id: The request ID + + Returns: + Remote allocation information + + Raises: + KeyError: If allocation info is missing + """ + try: + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] + except KeyError as e: + raise KeyError( + f"Remote allocation info missing for request {request_id}" + ) from e + + def _execute_write_task(self, task: WriteTask) -> None: + """Execute a single write task. 
+ + Args: + task: The write task to execute + + """ + # Get remote allocation info + request_info = self._get_remote_alloc_info(task.request_id) + + if request_info.block_ids is None: + logger.debug("Request %s remote block IDs not ready", task.request_id) + return + + # Wait for CUDA event + # The attention computation of the current layer cannot + # overlap with the kv transfer task, + # otherwise it will cause precision issues. + # This event is used to synchronize the kv transfer and computation tasks. + task.event.synchronize() + + # Update engine ID with DP rank + task.dst_engine_id = self.worker.get_engine_name_with_dp( + task.dst_engine_id, request_info.decode_dp_rank + ) + + # Get or create sessions + sessions, remote_moriio_meta = self.worker._get_built_session( + task.dst_engine_id + ) + + # Prepare transfer plan + plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) + + # Execute transfer + self._do_layer_write(plan, sessions) + + # Finalize if all layers complete + self._finalize_if_complete(task, request_info) + + def _prepare_transfer_plan( + self, + task: WriteTask, + request_info: RemoteAllocInfo, + remote_moriio_meta: MoRIIOAgentMetadata, + ) -> LayerTransferPlan: + """Prepare the transfer plan for a layer. 
+ + Args: + task: The write task + request_info: Remote allocation information + + Returns: + The transfer plan + """ + # Compute offsets if not cached + if request_info.transfer_offset is None: + offsets = self.worker._compute_block_transfer_offsets( + task.layer_name, + task.local_block_ids, + request_info.block_ids, + remote_moriio_meta, + ) + request_info.transfer_offset = offsets + + # Get session index + layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) + sess_idx = layer_names.index(task.layer_name) + + local_off, remote_off, sizes = request_info.transfer_offset + + return LayerTransferPlan( + request_id=task.request_id, + layer_name=task.layer_name, + sess_idx=sess_idx, + transfer_local_offsets=local_off, + transfer_remote_offsets=remote_off, + transfer_sizes=sizes, + use_batch=True, + ) + + def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: + """Perform the actual layer write. + + Args: + plan: The transfer plan + sessions: List of transfer sessions + """ + if plan.use_batch: + self.worker.moriio_wrapper.write_remote_data( + plan.transfer_sizes, + plan.transfer_local_offsets, + plan.transfer_remote_offsets, + sessions[plan.sess_idx], + ) + else: + for i in range(len(plan.transfer_local_offsets)): + self.worker.moriio_wrapper.write_remote_data_single( + plan.transfer_sizes[i], + plan.transfer_local_offsets[i], + plan.transfer_remote_offsets[i], + plan.sess_idx, + ) + + def _finalize_if_complete( + self, task: WriteTask, request_info: RemoteAllocInfo + ) -> None: + """Finalize transfer if all layers are complete. 
+ + Args: + task: The write task + request_info: Remote allocation information + """ + request_info.writes_done += 1 + + if request_info.writes_done >= self.worker.num_layers: + # Wait for transfer to complete + self.worker.moriio_wrapper.waiting_for_transfer_complete() + + remote_port = task.remote_notify_port + get_port_offset( + request_info.decode_dp_rank, self.worker.tp_rank + ) + # Consider using RDMA immediate data in decode side + # to eliminate the need for this notification. + # Consider including the first gen token from prefill in the notification + + # Send completion notification + self.worker.moriio_wrapper.send_notify( + task.request_id, task.remote_ip, remote_port + ) + # mark request as done, then we can free the blocks + with self.worker.moriio_wrapper.lock: + self.worker.moriio_wrapper.done_req_ids.append(task.request_id) + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + task.request_id + ] + logger.debug( + "Completed transfer for request %s, notified port %d", + task.request_id, + remote_port, + ) + + +class MoRIIOWrapper: + """Wrapper for MoRIIO engine operations. + + Handles both producer and consumer roles for KV cache transfers. 
+ + Args: + moriio_engine: MoRIIO engine instance + tp_rank: Tensor parallel rank + dp_rank: Data parallel rank + """ + + def __init__( + self, + moriio_engine: Optional["IOEngine"] = None, + tp_rank: int = 0, + dp_rank: int = 0, + ): + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moriio_engine = moriio_engine + self.remote_memory_metadata = None + self.local_memory_registered = False + self.local_memory_metadata = None + self.transfer_status: list[Any] = [] + self.remote_engine_ip: str | None = None + self.notify_port: int | None = None + self.lock = threading.Lock() + self.done_req_ids: list[str] = [] + self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} + self.done_write_cache_req_ids: list[str] = [] + self.notify_thread: threading.Thread | None = None + self.sessions: list[IOEngine.Session] = [] + self.paths: dict[str, zmq.Socket] = {} + + def set_moriio_engine(self, moriio_engine): + assert moriio_engine is not None, ( + "You Cannot pass None engine to MoRIIOWrapper!" 
+ ) + self.moriio_engine = moriio_engine + + def set_backend_type(self, backend_type): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE + num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS + poll_mode = PollCqMode.POLLING + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_worker_threads, + poll_mode, + ) + self.moriio_engine.create_backend(backend_type, rdma_cfg) + + def get_agent_metadata(self): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + engine_metadata = self.moriio_engine.get_engine_desc() + engine_metadata_packed = engine_metadata.pack() + return engine_metadata_packed + + def register_remote_engine(self, remote_packed_engine_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) + self.moriio_engine.register_remote_engine(consumer_engine_metadata) + return consumer_engine_metadata.key + + def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + try: + self.local_memory_metadata = self.moriio_engine.register_torch_tensor( + tensor + ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) + local_memory_metadata_packed = self.local_memory_metadata.pack() + except Exception as e: + raise MoRIIOError(f"Failed to register local memory: {e}") from e + self.local_memory_registered = True + return local_memory_metadata_packed + + def get_unpack_memory_metadata(self, packed_memory_metadata): + return MemoryDesc.unpack(packed_memory_metadata) + + def build_session(self, local_memory_metadata, remote_memory_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + return self.moriio_engine.create_session( + local_memory_metadata, 
remote_memory_metadata + ) + + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + transfer_status = session.batch_read( + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) + + return transfer_status + + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + write_uid = self.moriio_engine.allocate_transfer_uid() + + transfer_status = session.batch_write( + local_offset, remote_offset, transfer_size_byte, write_uid + ) + with self.lock: + self.transfer_status.append(transfer_status) + + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): + assert self.local_memory_registered, "You have not register local memory data!" 
+ assert self.moriio_engine is not None, "MoRIIO engine must be set first" + transfer_status = self.sessions[sess_idx].write( + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) + with self.lock: + self.transfer_status.append(transfer_status) + + def waiting_for_transfer_complete(self): + if not self.transfer_status: + return + + transfers_to_wait = [] + with self.lock: + transfers_to_wait = self.transfer_status[:] + self.transfer_status.clear() + + for status in transfers_to_wait: + try: + status.Wait() + if not status.Succeeded(): + logger.error( + "Transfer failed: %s, Code: %s", status.Message(), status.Code() + ) + raise TransferError("MoRIIO transfer failed!") + except Exception as e: + logger.error("Transfer %s failed: %s", status, e) + raise + + def async_wait_reqid(self): + assert self.notify_port is not None, "Notify port cannot be None" + + if self.notify_thread is not None: + return + + def _async_wait(): + host = "*" + path = make_zmq_path("tcp", host, self.notify_port) + logger.info("Node starting to listen notify from path = %s", path) + + with zmq_ctx(zmq.ROUTER, path) as sock: + while True: + try: + identity, msg = sock.recv_multipart() + self._handle_message(msg) + except Exception as e: + logger.error("Error processing message: %s", e) + raise HandshakeError(f"Error processing message: {e}") from e + + self.notify_thread = threading.Thread( + target=_async_wait, daemon=True, name="moriio-notify-listener" + ) + self.notify_thread.start() + + def _handle_message(self, msg: bytes): + """Handles incoming messages from remote nodes.""" + # Handles incoming remote messages: + # Prefill Role: + # [write] mode: receives block information (allocation) + # [read] mode: receives block release messages from decode side + # Decode Role: + # [write] mode: receives KV cache write completion notifications + handled = False + try: + data = msgpack.loads(msg) + if isinstance(data, dict) and "req_id" in data: + 
self._handle_structured_message(data) + + return + except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): + logger.debug("Failed to decode msgpack message, will try as string") + pass + + try: + msg_str = msg.decode("UTF-8") + if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX): + self._handle_completion_message(msg_str) + handled = True + except UnicodeDecodeError: + logger.warning("Received non-UTF8 message: %s", msg_str) + if not handled: + raise MoRIIOError(f"Unhandled message format: {msg_str}") + + def _handle_structured_message(self, data: dict): + assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" + req_id = data["req_id"] + block_notify_list = data.get("block_notify_list", []) + decode_dp_rank = data.get("decode_rank", 0) + assert len(block_notify_list) > 0, ( + "block_notify_list cannot be empty in remote allocate message" + ) + + with self.lock: + self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( + block_ids=block_notify_list, decode_dp_rank=decode_dp_rank + ) + + def _handle_completion_message(self, msg: str): + with self.lock: + if get_role() == ROLE.PRODUCER: + self.done_req_ids.append(msg) + else: + self.done_write_cache_req_ids.append(msg) + + def send_notify(self, req_ids, remote_ip, remote_port): + if not remote_ip or not remote_port: + logger.warning("Missing remote_ip or remote_port for notification") + return + + path = make_zmq_path("tcp", remote_ip, remote_port) + + if path not in self.paths: + ctx = zmq.Context.instance() + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) + self.paths[path] = sock + + req_list = req_ids if isinstance(req_ids, list) else [req_ids] + + sock = self.paths[path] + try: + for req_id in req_list: + if not isinstance(req_id, str): + logger.warning( + "Invalid req_id type: %s, expected str", type(req_id) + ) + continue + sock.send(req_id.encode("utf-8")) + except Exception as e: + logger.error("Failed to send notification 
to %s: %s", path, e) + self.paths.pop(path, None) + raise + + def pop_finished_req_ids(self): + # producer invocation: get the set of completed requests at the decode + with self.lock: + done_send = set(self.done_req_ids) + self.done_req_ids = [] + return done_send + + def pop_finished_write_req_ids(self): + # Call the consumer in write mode to get the collection after write completion + with self.lock: + done_write_cache = set(self.done_write_cache_req_ids) + self.done_write_cache_req_ids = [] + return done_write_cache + + def shutdown(self): + logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug("Closed ZMQ socket for path: %s", path) + except Exception as e: + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) + self.paths.clear() From a0330452d5cb6788271cfe8eccbadfd5e43c6e13 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 22 Dec 2025 06:56:33 +0000 Subject: [PATCH 54/62] fix type checking Signed-off-by: inkcherry --- .../kv_transfer/kv_connector/v1/moriio/moriio_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py index 4357c0335ef99..3a35c22622b89 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -34,9 +34,11 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( get_role, zmq_ctx, ) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( - MoRIIOConnectorWorker, -) + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + MoRIIOConnectorWorker, + ) logger = init_logger(__name__) try: From c4dcb3475e0cf3bcfdd56a55c6fb43ee7d0030ff Mon Sep 17 00:00:00 2001 From: 
inkcherry Date: Tue, 23 Dec 2025 08:10:03 +0000 Subject: [PATCH 55/62] add basic test Signed-off-by: inkcherry --- .../unit/test_moriio_connector.py | 605 ++++++++++++++++++ .../v1/moriio/moriio_connector.py | 8 +- 2 files changed, 609 insertions(+), 4 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_moriio_connector.py diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py new file mode 100644 index 0000000000000..d5da774f78026 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -0,0 +1,605 @@ +import pytest + +from vllm.platforms import current_platform +import contextlib +import inspect +import os +import tempfile +import textwrap +import time +import uuid +from collections import defaultdict +from unittest.mock import patch +from unittest.mock import MagicMock +import pytest +import ray +import torch +import msgspec +from vllm import LLM +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiKVConnectorStats, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + KVConnectorRole, + ) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata + +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + MoRIIOConnector, + MoRIIOConnectorScheduler, + MoRIIOConnectorWorker +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_shutdown, + has_kv_transfer_group, +) +from vllm.forward_context import 
ForwardContext +from vllm.platforms.interface import Platform +from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + +from .utils import create_request, create_scheduler +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) + +class FakeMorIIOWrapper(): + + def __init__(self, *args, **kwargs): + pass + + + def set_moriio_engine(self, moriio_engine): + pass + + def set_backend_type(self, backend_type): + pass + + def get_agent_metadata(self): + pass + + def register_remote_engine(self, remote_packed_engine_metadata): + pass + + def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + try: + self.local_memory_metadata = self.moriio_engine.register_torch_tensor( + tensor + ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) + local_memory_metadata_packed = self.local_memory_metadata.pack() + except Exception as e: + raise MoRIIOError(f"Failed to register local memory: {e}") from e + self.local_memory_registered = True + return local_memory_metadata_packed + + def get_unpack_memory_metadata(self, packed_memory_metadata): + pass + + def build_session(self, local_memory_metadata, remote_memory_metadata): + pass + + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + pass + + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + pass + + + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): + pass + + def waiting_for_transfer_complete(self): + pass + + def async_wait_reqid(self): + pass + + 
def _handle_message(self, msg: bytes): + pass + + def _handle_structured_message(self, data: dict): + pass + + def _handle_completion_message(self, msg: str): + pass + + def send_notify(self, req_ids, remote_ip, remote_port): + pass + + def pop_finished_req_ids(self): + pass + + def pop_finished_write_req_ids(self): + pass + + def shutdown(self): + pass + + +class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker): + REMOTE_ENGINE_ID = "remote_engine" + + def __init__( + self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs + ): + super().__init__(*args, **kwargs) + self._hand_shake_latency = hand_shake_latency + self.kv_cache_layout = kv_cache_layout + + +from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, + max_model_len: int = 10000, + enable_chunked_prefill: bool = True, + enable_permute_local_kv: bool = False, + role="kv_consumer" + # role="kv_producer" + +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + ) + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="bfloat16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="MoRIIOConnector", + kv_role=role, + enable_permute_local_kv=enable_permute_local_kv, + ) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) +from vllm.v1.kv_cache_interface 
import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, + ) + +@pytest.fixture +def moriio_read_mode(): + """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" + os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True' + yield + # Cleanup after test + os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None) + +def test_write_mode_basic_interface(): + """Unit test for basic MoriioConnector interface functionality.""" + + # Test Prefill wirte metadata + vllm_config = create_vllm_config(role="kv_consumer") + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + do_remote_prefill=False + ) + request_id = request.request_id + + scheduler.add_request(request) + + # Fake + request.kv_transfer_params['remote_notify_port']=4789 + request.kv_transfer_params['remote_block_ids']=None + request.kv_transfer_params["remote_host"]="127.0.0.1" + request.kv_transfer_params["remote_port"]=4789 + request.kv_transfer_params["remote_handshake_port"]=4789 + request.kv_transfer_params["remote_engine_id"]="test_engine" + # Remote Prefill, triggers NixlConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) + + assert len(kv_connector_metadata.reqs_to_save) == 1 + assert len(kv_connector_metadata.reqs_to_recv) == 0 + assert len(kv_connector_metadata.reqs_to_send) == 0 + assert request_id in kv_connector_metadata.reqs_to_save + req_meta = kv_connector_metadata.reqs_to_save[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id + + +def test_write_mode_chunk_prefill(): + """Unit test for basic MoriioConnector interface functionality.""" + MAX_NUM_BATCHED_TOKENS=64 + + NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2 + + # Test Prefill wirte metadata + vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer") + BLOCK_SIZE = vllm_config.cache_config.block_size + + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + do_remote_prefill=False + ) + request_id = request.request_id + + scheduler.add_request(request) + + # Fake + request.kv_transfer_params['remote_notify_port']=4789 + request.kv_transfer_params['remote_block_ids']=None + request.kv_transfer_params["remote_host"]="127.0.0.1" + request.kv_transfer_params["remote_port"]=4789 + request.kv_transfer_params["remote_handshake_port"]=4789 + request.kv_transfer_params["remote_engine_id"]="test_engine" + # Remote Prefill, triggers NixlConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) + + assert len(kv_connector_metadata.reqs_to_save) == 1 + assert len(kv_connector_metadata.reqs_to_recv) == 0 + assert len(kv_connector_metadata.reqs_to_send) == 0 + assert request_id in kv_connector_metadata.reqs_to_save + req_meta = kv_connector_metadata.reqs_to_save[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id + +def test_read_mode_basic_interface(moriio_read_mode): + + # test decode read + vllm_config = create_vllm_config(role="kv_consumer") + scheduler = create_scheduler(vllm_config) + # + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=False, + do_remote_prefill=True + ) + request_id = request.request_id + + scheduler.add_request(request) + block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] + # Fake + request.kv_transfer_params['remote_notify_port']=4789 + request.kv_transfer_params['remote_block_ids']=block_list + request.kv_transfer_params["remote_host"]="127.0.0.1" + request.kv_transfer_params["remote_port"]=4789 + request.kv_transfer_params["remote_handshake_port"]=4789 + request.kv_transfer_params["remote_engine_id"]="test_engine" + # Remote Prefill, triggers MorIIOConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) + assert len(kv_connector_metadata.reqs_to_save) == 0 + assert len(kv_connector_metadata.reqs_to_recv) == 1 + assert len(kv_connector_metadata.reqs_to_send) == 0 + # assert len(kv_connector_metadata.reqs_to_save) == 1 + assert request_id in kv_connector_metadata.reqs_to_recv + req_meta = kv_connector_metadata.reqs_to_recv[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id + + +def test_register_kv_caches(): + from vllm.utils.network_utils import get_ip + + ROLE="kv_consumer" + IP=get_ip() + DEFAULT_PORT=6301 + vllm_config = create_vllm_config(role=ROLE) + TP_RANK=0 + DP_RANK=0 + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend + + # Create test kv cache tensors 
using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + shared_tensor[0].data_ptr(), + + ] + mock_group = MagicMock() + mock_group.rank = TP_RANK # 设置 rank + mock_group.local_rank = TP_RANK + mock_group.world_size = 1 # 设置 world_size + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper" + ) as mock_moriio_wrapper, + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", + return_value=mock_group + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", + return_value=mock_group + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, + ) + + ): + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update({ + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + }) + + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeMoriIOConnectorWorker( + 
vllm_config, connector.engine_id, hand_shake_latency=0 + ) + + # Get the mock instance + mock_wrapper_instance = mock_moriio_wrapper.return_value + # connector.connector_worker.moriio_wrapper = mock_wrapper_instance + + # Reassure the shutdown() check that the thread is terminated + # mock_thread.return_value.is_alive.return_value = False + from mori.io import ( + EngineDesc, + IOEngine, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + ) + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + shared_tensor[0].data_ptr() + unique_tensor[1].data_ptr() + shared_tensor[0].data_ptr() + + assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data + assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data + assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data + expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + + assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key + +def test_moriio_handshake(): + from vllm.utils.network_utils import get_ip + + ROLE="kv_consumer" + IP=get_ip() + DEFAULT_PORT=6301 + vllm_config = create_vllm_config(role=ROLE) + TP_RANK=0 + DP_RANK=0 + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store 
tensor info for validation + expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + shared_tensor[0].data_ptr(), + + ] + mock_group = MagicMock() + mock_group.rank = TP_RANK + mock_group.local_rank = TP_RANK + mock_group.world_size = 1 + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", + return_value=0 + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", + return_value=mock_group + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", + return_value=mock_group + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, + ) + + ): + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update({ + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port":5670 + }) + + + + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + + from vllm.utils.network_utils import ( + get_ip, + make_zmq_path, + make_zmq_socket, + ) + from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx + import zmq + + + # Reassure the shutdown() check that the thread is terminated + # mock_thread.return_value.is_alive.return_value = False + from mori.io import ( + EngineDesc, + IOEngine, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + ) + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + # connector.layer_name_to_local_kv_cache_metadata["layer0"] 
expected_base_addrs = [ + + path = make_zmq_path("tcp", "127.0.0.1", 5670) + with zmq_ctx(zmq.DEALER, path) as sock: + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() + + if len(received_frame) != 2 or received_frame[0] != b"": + raise HandshakeError(f"Unexpected frame! {received_frame = }") + + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + assert isinstance(metadata, MoRIIOAgentMetadata) + + \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 4b6bd906d5d44..96e4c378b0b67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -342,8 +342,8 @@ class MoRIIOConnectorScheduler: local_block_ids = blocks.get_block_ids()[0] self._reqs_need_save[request.request_id] = (request, local_block_ids) - if params is not None and params.get("do_remote_prefill"): - if self.mode == MoRIIOMode.READ: + if params is not None and params.get("do_remote_prefill"): # + if self.mode == MoRIIOMode.READ: #read mode decode if remote_block_ids := params.get("remote_block_ids"): if all( p in params @@ -373,7 +373,7 @@ class MoRIIOConnectorScheduler: ) else: - assert request.kv_transfer_params is not None, ( + assert request.kv_transfer_params is not None, ( #write mode decode "kv_transfer_params should not be None" ) @@ -890,7 +890,7 @@ class MoRIIOConnectorWorker: layer_name_to_local_kv_cache_metadata: dict, ): """Background thread for getting new MoRIIO handshakes.""" - + logger.info("tmp") encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) From 94a920fb0c119e8a5baeb052a8f4b8045575b00f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 09:04:37 +0000 
Subject: [PATCH 56/62] format ut Signed-off-by: inkcherry --- .../unit/test_moriio_connector.py | 450 +++++++----------- 1 file changed, 171 insertions(+), 279 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index d5da774f78026..25b5663098272 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -1,53 +1,14 @@ -import pytest - -from vllm.platforms import current_platform -import contextlib -import inspect +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -import tempfile -import textwrap -import time -import uuid -from collections import defaultdict -from unittest.mock import patch -from unittest.mock import MagicMock -import pytest -import ray -import torch +from unittest.mock import MagicMock, patch + import msgspec -from vllm import LLM -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats -from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( - MultiKVConnectorStats, -) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( - KVConnectorRole, - ) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata +import pytest +import torch +import zmq -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( - MoRIIOConnector, - MoRIIOConnectorScheduler, - MoRIIOConnectorWorker -) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata - -from vllm.distributed.kv_transfer.kv_transfer_state import ( - ensure_kv_transfer_shutdown, - 
has_kv_transfer_group, -) -from vllm.forward_context import ForwardContext -from vllm.platforms.interface import Platform -from vllm.sampling_params import SamplingParams -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput -from vllm.v1.request import RequestStatus -from vllm.v1.structured_output import StructuredOutputManager - -from .utils import create_request, create_scheduler +from tests.conftest import _find_free_port from vllm.config import ( CacheConfig, DeviceConfig, @@ -56,13 +17,68 @@ from vllm.config import ( SchedulerConfig, VllmConfig, ) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + MoRIIOAgentMetadata, + MoRIIOConnectorMetadata, + MoRIIOConstants, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + KVConnectorRole, + MoRIIOConnector, + MoRIIOConnectorWorker, +) +from vllm.utils.network_utils import ( + get_ip, + make_zmq_path, +) -class FakeMorIIOWrapper(): +from .utils import create_request, create_scheduler + +@pytest.fixture +def mock_parallel_groups(): + """Mock parallel group functions.""" + mock_group = MagicMock() + mock_group.rank = 0 + mock_group.local_rank = 0 + mock_group.world_size = 1 + + with ( + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common", + get_tensor_model_parallel_rank=MagicMock(return_value=0), + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + ), + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + get_world_group=MagicMock(return_value=mock_group), + get_tp_group=MagicMock(return_value=mock_group), + ), + ): + yield mock_group + + +def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789): + """Setup KV transfer parameters for a request.""" + 
request.kv_transfer_params.update( + { + "remote_notify_port": fake_port, + "remote_block_ids": None, + "remote_host": remote_host, + "remote_port": fake_port, + "remote_handshake_port": fake_port, + "remote_engine_id": "test_engine", + } + ) + return request + + +class FakeMorIIOWrapper: def __init__(self, *args, **kwargs): pass - def set_moriio_engine(self, moriio_engine): pass @@ -76,19 +92,7 @@ class FakeMorIIOWrapper(): pass def register_local_tensor(self, tensor: torch.Tensor): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - try: - self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor - ) - assert self.local_memory_metadata is not None, ( - "register_torch_tensor returned None" - ) - local_memory_metadata_packed = self.local_memory_metadata.pack() - except Exception as e: - raise MoRIIOError(f"Failed to register local memory: {e}") from e - self.local_memory_registered = True - return local_memory_metadata_packed + pass def get_unpack_memory_metadata(self, packed_memory_metadata): pass @@ -98,18 +102,17 @@ class FakeMorIIOWrapper(): def read_remote_data( self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): + ): pass def write_remote_data( self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): + ): pass - def write_remote_data_single( self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 - ): + ): pass def waiting_for_transfer_complete(self): @@ -140,19 +143,15 @@ class FakeMorIIOWrapper(): pass -class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker): +class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" def __init__( self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs ): super().__init__(*args, **kwargs) - self._hand_shake_latency = hand_shake_latency - self.kv_cache_layout = kv_cache_layout -from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput - def 
create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, @@ -161,9 +160,7 @@ def create_vllm_config( max_model_len: int = 10000, enable_chunked_prefill: bool = True, enable_permute_local_kv: bool = False, - role="kv_consumer" - # role="kv_producer" - + role="kv_consumer", ) -> VllmConfig: """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( @@ -198,46 +195,20 @@ def create_vllm_config( kv_transfer_config=kv_transfer_config, device_config=DeviceConfig("cpu"), ) -from vllm.v1.kv_cache_interface import ( - FullAttentionSpec, - KVCacheConfig, - KVCacheGroupSpec, -) -def create_scheduler( - vllm_config: VllmConfig, - num_blocks: int = 10000, -) -> Scheduler: - """Initialize Scheduler For Testing.""" - block_size = vllm_config.cache_config.block_size - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec( - ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) - ) - ], - ) - vllm_config.cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - block_size=block_size, - ) + @pytest.fixture def moriio_read_mode(): """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" - os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True' + os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" yield # Cleanup after test - os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None) + os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) + def test_write_mode_basic_interface(): """Unit test for basic MoriioConnector interface functionality.""" - + # Test Prefill wirte metadata vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) @@ -252,19 +223,15 @@ def test_write_mode_basic_interface(): 
block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True, - do_remote_prefill=False + do_remote_prefill=False, ) request_id = request.request_id scheduler.add_request(request) - - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=None - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + + # Fake Config + request = _setup_kv_transfer_request(request) + # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata @@ -288,36 +255,34 @@ def test_write_mode_basic_interface(): def test_write_mode_chunk_prefill(): """Unit test for basic MoriioConnector interface functionality.""" - MAX_NUM_BATCHED_TOKENS=64 + MAX_NUM_BATCHED_TOKENS = 64 - NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2 + NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2 # Test Prefill wirte metadata - vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer") + vllm_config = create_vllm_config( + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer" + ) BLOCK_SIZE = vllm_config.cache_config.block_size scheduler = create_scheduler(vllm_config) # 2 Full Blocks and 1 Half Block. 
- + request = create_request( request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True, - do_remote_prefill=False + do_remote_prefill=False, ) request_id = request.request_id scheduler.add_request(request) - - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=None - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + + # Fake Config + + request = _setup_kv_transfer_request(request) # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata @@ -338,8 +303,8 @@ def test_write_mode_chunk_prefill(): ): assert block_id == block.block_id + def test_read_mode_basic_interface(moriio_read_mode): - # test decode read vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) @@ -354,20 +319,20 @@ def test_read_mode_basic_interface(moriio_read_mode): block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=False, - do_remote_prefill=True + do_remote_prefill=True, ) request_id = request.request_id scheduler.add_request(request) - block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=block_list - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks[request_id] + # Fake kv config + request = _setup_kv_transfer_request(request) + 
request.kv_transfer_params["remote_block_ids"] = block_list + # Remote Prefill, triggers MorIIOConnectorMetadata. + scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None @@ -375,7 +340,6 @@ def test_read_mode_basic_interface(moriio_read_mode): assert len(kv_connector_metadata.reqs_to_save) == 0 assert len(kv_connector_metadata.reqs_to_recv) == 1 assert len(kv_connector_metadata.reqs_to_send) == 0 - # assert len(kv_connector_metadata.reqs_to_save) == 1 assert request_id in kv_connector_metadata.reqs_to_recv req_meta = kv_connector_metadata.reqs_to_recv[request_id] @@ -388,16 +352,15 @@ def test_read_mode_basic_interface(moriio_read_mode): assert block_id == block.block_id -def test_register_kv_caches(): - from vllm.utils.network_utils import get_ip - - ROLE="kv_consumer" - IP=get_ip() - DEFAULT_PORT=6301 +def test_register_kv_caches(mock_parallel_groups): + ROLE = "kv_consumer" + IP = get_ip() vllm_config = create_vllm_config(role=ROLE) - TP_RANK=0 - DP_RANK=0 - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + DEFAULT_PORT = 6301 + TP_RANK = 0 + DP_RANK = 0 + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend # Create test kv cache tensors using proper backend shape @@ -411,97 +374,80 @@ def test_register_kv_caches(): "layer1": unique_tensor, "layer2": shared_tensor, } - - # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - shared_tensor[0].data_ptr(), - ] - mock_group = MagicMock() - mock_group.rank = TP_RANK # 设置 rank - mock_group.local_rank = TP_RANK - mock_group.world_size = 1 # 设置 world_size with ( patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper" - ) as mock_moriio_wrapper, 
- patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", - return_value=0 + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" ), patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", - return_value=mock_group + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ) - ): # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update({ - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - }) - + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + } + ) + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeMoriIOConnectorWorker( + connector.connector_worker = FakeMorIIOConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) - # Get the mock instance - mock_wrapper_instance = mock_moriio_wrapper.return_value - # connector.connector_worker.moriio_wrapper = mock_wrapper_instance - - # Reassure the shutdown() check that the thread is terminated - # mock_thread.return_value.is_alive.return_value = False from mori.io import ( - EngineDesc, - IOEngine, MemoryDesc, - PollCqMode, - RdmaBackendConfig, ) + # Execute register_kv_caches connector.register_kv_caches(kv_caches) 
shared_tensor[0].data_ptr() unique_tensor[1].data_ptr() shared_tensor[0].data_ptr() - assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data - assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data - assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).data + ) + assert ( + unique_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer1" + ][0] + ).data + ) + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer2" + ][0] + ).data + ) expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" - assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key + assert ( + MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).engine_key + == expected_engine_key + ) -def test_moriio_handshake(): - from vllm.utils.network_utils import get_ip - ROLE="kv_consumer" - IP=get_ip() - DEFAULT_PORT=6301 +def test_moriio_handshake(mock_parallel_groups): + ROLE = "kv_consumer" vllm_config = create_vllm_config(role=ROLE) - TP_RANK=0 - DP_RANK=0 - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend # Create test kv cache tensors using proper backend shape @@ -515,91 +461,37 @@ def test_moriio_handshake(): "layer1": unique_tensor, "layer2": shared_tensor, } 
- - # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - shared_tensor[0].data_ptr(), - ] - mock_group = MagicMock() - mock_group.rank = TP_RANK - mock_group.local_rank = TP_RANK - mock_group.world_size = 1 with ( patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", - return_value=0 + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ) - ): + handshake_port = _find_free_port() # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update({ - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - "handshake_port":5670 - }) - + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port": handshake_port, + } + ) - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - - from vllm.utils.network_utils import ( - get_ip, - make_zmq_path, - make_zmq_socket, - ) - from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx - import zmq - - # Reassure the shutdown() check that the 
thread is terminated - # mock_thread.return_value.is_alive.return_value = False - from mori.io import ( - EngineDesc, - IOEngine, - MemoryDesc, - PollCqMode, - RdmaBackendConfig, - ) # Execute register_kv_caches connector.register_kv_caches(kv_caches) - # connector.layer_name_to_local_kv_cache_metadata["layer0"] expected_base_addrs = [ - - path = make_zmq_path("tcp", "127.0.0.1", 5670) + path = make_zmq_path("tcp", "127.0.0.1", handshake_port) with zmq_ctx(zmq.DEALER, path) as sock: sock.send(MoRIIOConstants.GET_META_MSG) received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - raise HandshakeError(f"Unexpected frame! {received_frame = }") + raise ValueError(f"Unexpected frame! {received_frame = }") metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) assert isinstance(metadata, MoRIIOAgentMetadata) - - \ No newline at end of file From b36893b3056256e6929abc441f40e303937930f1 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 10:09:45 +0000 Subject: [PATCH 57/62] refine ut Signed-off-by: inkcherry --- .../unit/test_moriio_connector.py | 141 ++++++++++++------ 1 file changed, 94 insertions(+), 47 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 25b5663098272..c31d5d843e85a 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util import os from unittest.mock import MagicMock, patch @@ -28,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import MoRIIOConnector, MoRIIOConnectorWorker, ) +from vllm.platforms import current_platform from vllm.utils.network_utils import ( get_ip, make_zmq_path, @@ -35,10 +37,17 
@@ from vllm.utils.network_utils import ( from .utils import create_request, create_scheduler +aiter_available = importlib.util.find_spec("aiter") is not None +mori_available = importlib.util.find_spec("mori") is not None +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and mori_available), + reason="MoRIIOs are only available on ROCm with aiter package installed", +) + @pytest.fixture def mock_parallel_groups(): - """Mock parallel group functions.""" + """Mock tensor/data parallel group functions for single-rank tests.""" mock_group = MagicMock() mock_group.rank = 0 mock_group.local_rank = 0 @@ -76,6 +85,7 @@ def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789) class FakeMorIIOWrapper: + # A fake MoRIIOWrapper for testing purposes def __init__(self, *args, **kwargs): pass @@ -144,6 +154,7 @@ class FakeMorIIOWrapper: class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker): + # Define a fake remote engine id for testing REMOTE_ENGINE_ID = "remote_engine" def __init__( @@ -162,7 +173,7 @@ def create_vllm_config( enable_permute_local_kv: bool = False, role="kv_consumer", ) -> VllmConfig: - """Initialize VllmConfig For Testing.""" + """Initialize VllmConfig for testing.""" scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, @@ -199,18 +210,18 @@ def create_vllm_config( @pytest.fixture def moriio_read_mode(): - """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" + """Force the connector into read mode via env for tests.""" os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" yield # Cleanup after test os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) -def test_write_mode_basic_interface(): - """Unit test for basic MoriioConnector interface functionality.""" +def test_write_mode_saves_local_block_ids(): + """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save.""" - # Test Prefill wirte metadata - vllm_config = 
create_vllm_config(role="kv_consumer") + # Setup Scheduler and Request + vllm_config = create_vllm_config(role="kv_producer") scheduler = create_scheduler(vllm_config) # 2 Full Blocks and 1 Half Block. @@ -235,13 +246,21 @@ def test_write_mode_basic_interface(): # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None + assert kv_connector_metadata is not None, "kv_connector_metadata is None" assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 1 - assert len(kv_connector_metadata.reqs_to_recv) == 0 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_save + assert len(kv_connector_metadata.reqs_to_save) == 1, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 0, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) req_meta = kv_connector_metadata.reqs_to_save[request_id] for block_id, block in zip( @@ -250,18 +269,17 @@ def test_write_mode_basic_interface(): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" -def test_write_mode_chunk_prefill(): - """Unit test for basic MoriioConnector interface functionality.""" +def test_write_mode_with_chunked_prefill_saves_local_block_ids(): + """Write mode with chunked prefill still records correct local block ids.""" + # Setup Scheduler and Request MAX_NUM_BATCHED_TOKENS = 64 - NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2 - # Test Prefill wirte metadata vllm_config = create_vllm_config( - max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, 
role="kv_consumer" + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer" ) BLOCK_SIZE = vllm_config.cache_config.block_size @@ -281,18 +299,22 @@ def test_write_mode_chunk_prefill(): scheduler.add_request(request) # Fake Config - request = _setup_kv_transfer_request(request) - # Remote Prefill, triggers NixlConnectorMetadata. - scheduler_output = scheduler.schedule() - kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 1 - assert len(kv_connector_metadata.reqs_to_recv) == 0 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_save + # Remote Prefill with chunked prefill, triggers multiple schedules. + expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)] + kv_connector_metadata = None + for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts): + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + + assert len(kv_connector_metadata.reqs_to_save) == expected_save + assert len(kv_connector_metadata.reqs_to_recv) == expected_recv + assert len(kv_connector_metadata.reqs_to_send) == expected_send + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) req_meta = kv_connector_metadata.reqs_to_save[request_id] for block_id, block in zip( @@ -301,14 +323,16 @@ def test_write_mode_chunk_prefill(): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" -def test_read_mode_basic_interface(moriio_read_mode): - # test decode read +def test_read_mode_loads_remote_block_ids(moriio_read_mode): + """Read mode loads remote block ids into local cache mapping.""" + + # Setup Scheduler 
and Request vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) - # + # 2 Full Blocks and 1 Half Block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 @@ -327,20 +351,32 @@ def test_read_mode_basic_interface(moriio_read_mode): block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].req_to_blocks[request_id] - # Fake kv config + request = _setup_kv_transfer_request(request) + + # Set remote block ids to be fetched. request.kv_transfer_params["remote_block_ids"] = block_list # Remote Prefill, triggers MorIIOConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 0 - assert len(kv_connector_metadata.reqs_to_recv) == 1 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_recv + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), ( + "kv_connector_metadata is not MoRIIOConnectorMetadata" + ) + assert len(kv_connector_metadata.reqs_to_save) == 0, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 1, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_recv, ( + "Request ID not in reqs_to_recv" + ) req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( @@ -349,10 +385,14 @@ def test_read_mode_basic_interface(moriio_read_mode): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" +@pytest.mark.skipif( + not aiter_available, 
reason="Requires aiter package for ROCm FlashAttention backend" +) def test_register_kv_caches(mock_parallel_groups): + """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches.""" ROLE = "kv_consumer" IP = get_ip() vllm_config = create_vllm_config(role=ROLE) @@ -403,10 +443,8 @@ def test_register_kv_caches(mock_parallel_groups): # Execute register_kv_caches connector.register_kv_caches(kv_caches) - shared_tensor[0].data_ptr() - unique_tensor[1].data_ptr() - shared_tensor[0].data_ptr() + # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata assert ( shared_tensor.data_ptr() == MemoryDesc.unpack( @@ -431,8 +469,9 @@ def test_register_kv_caches(mock_parallel_groups): ][0] ).data ) - expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + # Verify engine keys + expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" assert ( MemoryDesc.unpack( connector.connector_worker.layer_name_to_local_kv_cache_metadata[ @@ -443,7 +482,12 @@ def test_register_kv_caches(mock_parallel_groups): ) -def test_moriio_handshake(mock_parallel_groups): +@pytest.mark.skipif( + not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" +) +def test_moriio_handshake_returns_metadata(mock_parallel_groups): + """MoRIIO handshake socket returns valid agent metadata over ZMQ.""" + ROLE = "kv_consumer" vllm_config = create_vllm_config(role=ROLE) from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend @@ -478,11 +522,12 @@ def test_moriio_handshake(mock_parallel_groups): "handshake_port": handshake_port, } ) - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) # Execute register_kv_caches connector.register_kv_caches(kv_caches) + + # Connect to handshake socket and request metadata path = make_zmq_path("tcp", "127.0.0.1", handshake_port) with zmq_ctx(zmq.DEALER, path) as sock: sock.send(MoRIIOConstants.GET_META_MSG) @@ -494,4 +539,6 @@ def 
test_moriio_handshake(mock_parallel_groups): metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) - assert isinstance(metadata, MoRIIOAgentMetadata) + assert isinstance(metadata, MoRIIOAgentMetadata), ( + "Decoded metadata is not MoRIIOAgentMetadata" + ) From 78d1683957fc36ea0accb127d2206d7cc248d4c6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 10:11:23 +0000 Subject: [PATCH 58/62] fix typo Signed-off-by: inkcherry --- tests/v1/kv_connector/unit/test_moriio_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index c31d5d843e85a..df2e691f5bb3a 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -243,7 +243,7 @@ def test_write_mode_saves_local_block_ids(): # Fake Config request = _setup_kv_transfer_request(request) - # Remote Prefill, triggers NixlConnectorMetadata. + # Remote Prefill, triggers MoRIIOConnectorMetadata. 
scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None, "kv_connector_metadata is None" From d2a18332b7afbbf0055f0a8a630da4a62f2dd64d Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 10:47:17 +0000 Subject: [PATCH 59/62] format Signed-off-by: inkcherry --- .../moriio_toy_proxy_server.py | 21 ++++++++++++------- .../v1/moriio/moriio_connector.py | 8 +++---- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py index 98481787268c6..a9feb82267f04 100644 --- a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py +++ b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py @@ -210,11 +210,16 @@ async def handle_request(): prefill_instance_endpoint = None decode_instance_endpoint = None - + error_msg = ( + "Service Unavailable: No prefill or decode instances are registered." 
+ ) if not prefill_instances or not decode_instances: return await make_response( - ("Service Unavailable: No prefill or decode instances are registered.", - 503)) + ( + error_msg, + 503, + ) + ) pid = request_nums % len(prefill_instances) did = request_nums % len(decode_instances) prefill_instance_endpoint = prefill_instances[pid] @@ -297,10 +302,12 @@ async def handle_request(): return response except Exception as e: logger.exception("An error occurred while handling the request: %s", e) - return await make_response(( - f"Internal Server Error: {e!s}", - 500, - )) + return await make_response( + ( + f"Internal Server Error: {e!s}", + 500, + ) + ) if __name__ == "__main__": diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 96e4c378b0b67..4b6bd906d5d44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -342,8 +342,8 @@ class MoRIIOConnectorScheduler: local_block_ids = blocks.get_block_ids()[0] self._reqs_need_save[request.request_id] = (request, local_block_ids) - if params is not None and params.get("do_remote_prefill"): # - if self.mode == MoRIIOMode.READ: #read mode decode + if params is not None and params.get("do_remote_prefill"): + if self.mode == MoRIIOMode.READ: if remote_block_ids := params.get("remote_block_ids"): if all( p in params @@ -373,7 +373,7 @@ class MoRIIOConnectorScheduler: ) else: - assert request.kv_transfer_params is not None, ( #write mode decode + assert request.kv_transfer_params is not None, ( "kv_transfer_params should not be None" ) @@ -890,7 +890,7 @@ class MoRIIOConnectorWorker: layer_name_to_local_kv_cache_metadata: dict, ): """Background thread for getting new MoRIIO handshakes.""" - logger.info("tmp") + encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = 
len(encoded_data) From 8c629bf22e872983260e9f63cc0cff6a6d0d8be9 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 11:15:15 +0000 Subject: [PATCH 60/62] ci Signed-off-by: inkcherry --- docker/Dockerfile.rocm_base | 19 ++++++++++++++++++- .../installation/gpu.rocm.inc.md | 16 +++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index c5e94ee1f6928..c820761b6b215 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -11,6 +11,8 @@ ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG AITER_BRANCH="6af8b687" ARG AITER_REPO="https://github.com/ROCm/aiter.git" +ARG MORI_BRANCH="2d02c6a9" +ARG MORI_REPO="https://github.com/ROCm/mori.git" FROM ${BASE_IMAGE} AS base @@ -20,6 +22,7 @@ ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151 ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} ENV AITER_ROCM_ARCH=gfx942;gfx950 +ENV MORI_GPU_ARCHS=gfx942;gfx950 # Required for RCCL in ROCm7.1 ENV HSA_NO_SCRATCH_RECLAIM=1 @@ -33,7 +36,7 @@ ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies RUN apt-get update -y \ - && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \ + && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \ && for i in 1 2 3; do \ add-apt-repository -y ppa:deadsnakes/ppa && break || \ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ @@ -67,6 +70,18 @@ RUN cd /opt/rocm/share/amd_smi \ && pip wheel . 
--wheel-dir=dist RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install +FROM base AS build_mori +ARG MORI_BRANCH +ARG MORI_REPO +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN git clone ${MORI_REPO} +RUN cd mori \ + && git checkout ${MORI_BRANCH} \ + && git submodule update --init --recursive \ + && python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl +RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install + FROM base AS build_pytorch ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH @@ -132,6 +147,8 @@ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs FROM base AS final RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \ diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index 21120cc6fcd98..032f33f0e9445 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -99,8 +99,22 @@ Currently, there are no pre-built ROCm wheels. - You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. - The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). +4. If you want to use MORI for EP or PD disaggregation, you can install [MORI](https://github.com/ROCm/mori) using the following steps: -4. Build vLLM. 
For example, vLLM on ROCM 7.0 can be built with the following steps: + ```bash + git clone https://github.com/ROCm/mori.git + cd mori + git checkout $MORI_BRANCH_OR_COMMIT + git submodule sync; git submodule update --init --recursive + MORI_GPU_ARCHS="gfx942;gfx950" python3 install . + ``` + + !!! note + - You will need to config the `$MORI_BRANCH_OR_COMMIT` for your purpose. + - The validated `$MORI_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). + + +5. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps: ???+ console "Commands" From d16eec7aae24391a8cf2466e642a46cc920d7c2c Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 11:19:45 +0000 Subject: [PATCH 61/62] doc Signed-off-by: inkcherry --- docs/getting_started/installation/gpu.rocm.inc.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index 032f33f0e9445..f2b8a7a811f9b 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -98,7 +98,8 @@ Currently, there are no pre-built ROCm wheels. !!! note - You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. - The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). - + + 4. 
If you want to use MORI for EP or PD disaggregation, you can install [MORI](https://github.com/ROCm/mori) using the following steps: ```bash From 2cc09f5ea62622170026caeab5d7e9c582324696 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 11:26:37 +0000 Subject: [PATCH 62/62] fix main branch Signed-off-by: inkcherry --- tests/v1/kv_connector/unit/test_moriio_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index df2e691f5bb3a..1cc6988635d8d 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -179,6 +179,7 @@ def create_vllm_config( max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len, enable_chunked_prefill=enable_chunked_prefill, + is_encoder_decoder=False, ) model_config = ModelConfig( model=model,