diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 42821d341d6b6..954a5153ff67d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -182,7 +182,8 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MoRIIOConnector", "vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector", - "MoRIIOConnector") + "MoRIIOConnector", +) KVConnectorFactory.register_connector( "OffloadingConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 35a686e7a8fd4..2856733484cab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -11,7 +11,8 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional +from weakref import ref as weakref_ref import msgpack import msgspec @@ -20,21 +21,25 @@ import torch import zmq from vllm import envs -from vllm.attention.selector import get_attn_backend from vllm.attention.backends.registry import AttentionBackendEnum - +from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, get_world_group) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + get_world_group, +) from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus -from weakref import ref as weakref_ref if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -42,50 +47,55 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -from dataclasses import field -from queue import Empty, Queue -from enum import Enum import logging - +from dataclasses import field +from enum import Enum +from queue import Empty, Queue logger = init_logger(__name__) -Transfer = tuple[int, float] +Transfer = tuple[int, float] EngineId = str ReqId = str class MoRIIOConstants: """Constants for MoRIIO connector.""" - + # ZMQ message types GET_META_MSG = b"get_meta_msg" POP_DONE_RECV = b"pop_done_recv" OVER = b"OVER" COMPLETION_PREFIX = "cmpl" - # Default GPU count per node for standard configurations - RANK_PER_NODE = 8 - - - + RANK_PER_NODE = 8 + + try: import mori - from mori.io import (BackendType, EngineDesc, IOEngine, IOEngineConfig, - MemoryDesc, StatusCode) + from mori.io import ( + BackendType, + EngineDesc, + IOEngine, + IOEngineConfig, + MemoryDesc, + StatusCode, + ) + logger.info("MoRIIO is available") MoRIIO_enabled = True except ImportError: logger.error("MoRIIO is not available") MoRIIO_enabled = False - + + @dataclass class WriteTask: request_id: str dst_engine_id: str 
local_block_ids: list[int] - remote_block_ids_hint: Optional[list[int]] + remote_block_ids_hint: list[int] | None layer_name: str event: torch.cuda.Event remote_notify_port: int @@ -93,9 +103,11 @@ class WriteTask: enqueue_time: float = field(default_factory=time.perf_counter) retried: int = 0 + @dataclass class LayerTransferPlan: """Plan for transferring a single layer.""" + request_id: str layer_name: str sess_idx: int @@ -103,10 +115,12 @@ class LayerTransferPlan: transfer_remote_offsets: list[int] transfer_sizes: list[int] use_batch: bool = True - + + @dataclass class RemoteAllocInfo: """Information about remote block allocation.""" + block_ids: list[int] writes_done: int = 0 decode_dp_rank: int = 0 @@ -118,15 +132,16 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" + class RoleManager: """Manages role state across the connector.""" - + _instance: Optional["RoleManager"] = None _lock = threading.Lock() - + def __init__(self) -> None: self._role: ROLE = ROLE.NOTINIT - + @classmethod def get_instance(cls) -> "RoleManager": if cls._instance is None: @@ -134,12 +149,12 @@ class RoleManager: if cls._instance is None: cls._instance = cls() return cls._instance - + def set_role(self, role: ROLE) -> None: """Set the current role.""" with self._lock: self._role = role - + def get_role(self) -> ROLE: """Get the current role.""" return self._role @@ -149,6 +164,7 @@ def set_role(role: ROLE): """Set the global role.""" RoleManager.get_instance().set_role(role) + def get_role() -> ROLE: """Get the global role.""" return RoleManager.get_instance().get_role() @@ -158,30 +174,37 @@ class MoRIIOMode(Enum): READ = "read" WRITE = "write" + class MoRIIOError(Exception): """Base exception for MoRIIO operations.""" + pass + class HandshakeError(MoRIIOError): """Exception raised when handshake fails.""" + pass + class TransferError(MoRIIOError): """Exception raised when transfer fails.""" + pass def get_moriio_mode() -> MoRIIOMode: - read_mode = os.environ.get('MORIIO_CONNECTOR_READ_MODE', 'false').lower() + read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower() # logger.info(f"MoRIIO Connector Read Mode = {read_mode}") - if read_mode in ('true', '1', 'yes', 'on'): + if read_mode in ("true", "1", "yes", "on"): return MoRIIOMode.READ else: return MoRIIOMode.WRITE -def get_port_offset(dp_rank: int,tp_rank: int, tp_size:int=1) -> int: - return ((dp_rank)*tp_size+tp_rank )%MoRIIOConstants.RANK_PER_NODE +def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: + return ((dp_rank) * tp_size + tp_rank) % MoRIIOConstants.RANK_PER_NODE + @dataclass class MoRIIOConfig: @@ -198,18 +221,16 @@ class MoRIIOConfig: dp_rank: int dp_size: int tp_size: int - + @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - - # Port Configuration: # local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it) # proxy_ping_port -> Remote proxy's heartbeat ingress port # http_port -> Instance's HTTP service endpoint # local_kv_port -> KV service port for Mori engine # notify_port -> For synchronizing stages between nodes - + kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -217,53 +238,54 @@ class MoRIIOConfig: base_kv_port = int(kv_transfer_config.kv_port) base_ping_port = int(extra_config["local_ping_port"]) base_notify_port = int(extra_config["notify_port"]) - dp_size=vllm_config.parallel_config.data_parallel_size - 
tp_size=get_tensor_model_parallel_world_size() - port_offset=get_port_offset(dp_rank,tp_rank) - + dp_size = vllm_config.parallel_config.data_parallel_size + tp_size = get_tensor_model_parallel_world_size() + port_offset = get_port_offset(dp_rank, tp_rank) + return cls( local_ip=get_ip(), local_kv_port=base_kv_port + port_offset, proxy_ip=extra_config["proxy_ip"], proxy_port=int(extra_config["proxy_port"]), - local_ping_port=base_ping_port+port_offset, + local_ping_port=base_ping_port + port_offset, proxy_ping_port=int(extra_config["proxy_ping_port"]), - http_port=int(extra_config['http_port']), - handshake_port=int(extra_config['handshake_port']), + http_port=int(extra_config["http_port"]), + handshake_port=int(extra_config["handshake_port"]), notify_port=base_notify_port + port_offset, tp_rank=tp_rank, dp_rank=dp_rank, dp_size=dp_size, - tp_size=tp_size, + tp_size=tp_size, ) """Write task execution logic for MoRIIO connector.""" + class MoRIIOWriter: """Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library for RDMA-based communication between prefill and decode instances.""" - + def __init__(self, worker: "MoRIIOConnectorWorker"): """Initialize the writer. - + Args: worker: Reference to the parent worker """ # self.worker = worker - self._worker_ref: "weakref_ref[MoRIIOConnectorWorker]" = weakref_ref(worker) + self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) self._write_task_q: Queue[WriteTask] = Queue() self._write_worker_started = False self._write_worker_lock = threading.Lock() self._deferred_tasks: list[WriteTask] = [] - + @property def worker(self) -> "MoRIIOConnectorWorker": """Get the worker instance. - + Returns: The parent worker instance - + Raises: RuntimeError: If worker has been garbage collected """ @@ -271,47 +293,42 @@ class MoRIIOWriter: if worker is None: raise RuntimeError("Parent worker has been garbage collected") return worker - + def ensure_worker_started(self) -> None: """Ensure the background write worker is running.""" if self._write_worker_started: return self._write_worker_started = True with self._write_worker_lock: - thread = threading.Thread( - target=self._write_worker_loop, - daemon=True, - name="moriio-write-worker" + target=self._write_worker_loop, daemon=True, name="moriio-write-worker" ) thread.start() logger.info("Started MoRIIO write worker thread") - + def schedule_write(self, task: WriteTask) -> None: """Schedule a write task. 
- + Args: task: The write task to schedule """ self.ensure_worker_started() self._write_task_q.put(task) - + def _write_worker_loop(self) -> None: """Main loop for the write worker thread.""" logger.info("Write worker loop started") - + while True: # Process deferred tasks first self._process_deferred_tasks() - + # Get new task try: - task = self._write_task_q.get( - timeout=0.01 - ) + task = self._write_task_q.get(timeout=0.01) except Empty: continue - + # Check if remote blocks are ready if not self._is_remote_ready(task): # task.retry_count += 1 @@ -321,128 +338,114 @@ class MoRIIOWriter: # task.request_id, task.retry_count # ) continue - + # Execute the task self._execute_write_task(task) - - + def _process_deferred_tasks(self) -> None: """Process tasks that were previously deferred.""" if not self._deferred_tasks: return - + still_deferred: list[WriteTask] = [] for task in self._deferred_tasks: if self._is_remote_ready(task): - - self._execute_write_task(task) + self._execute_write_task(task) else: still_deferred.append(task) - + self._deferred_tasks = still_deferred - + def _is_remote_ready(self, task: WriteTask) -> bool: """Check if remote blocks are allocated for this task. - + Args: task: The write task - + Returns: True if remote blocks are ready """ - return (task.request_id in - self.worker.moriio_wrapper.done_remote_allocate_req_dict) - + return ( + task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict + ) + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: """Get remote allocation info for a request. - + Args: request_id: The request ID - + Returns: Remote allocation information - + Raises: KeyError: If allocation info is missing """ try: - return self.worker.moriio_wrapper.done_remote_allocate_req_dict[ - request_id - ] + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] except KeyError as e: raise KeyError( f"Remote allocation info missing for request {request_id}" ) from e - + def _execute_write_task(self, task: WriteTask) -> None: """Execute a single write task. - + Args: task: The write task to execute - + """ # Get remote allocation info request_info = self._get_remote_alloc_info(task.request_id) - + if request_info.block_ids is None: - logger.debug( - "Request %s remote block IDs not ready", - task.request_id - ) + logger.debug("Request %s remote block IDs not ready", task.request_id) return - + # Wait for CUDA event task.event.synchronize() - + # Update engine ID with DP rank - task.dst_engine_id = ( - f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" - ) - + task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" + # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) - + # Prepare transfer plan plan = self._prepare_transfer_plan(task, request_info) - + # Execute transfer self._do_layer_write(plan, sessions) - + # Finalize if all layers complete self._finalize_if_complete(task, request_info) - + def _prepare_transfer_plan( - self, - task: WriteTask, - request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo ) -> LayerTransferPlan: """Prepare the transfer plan for a layer. 
- + Args: task: The write task request_info: Remote allocation information - + Returns: The transfer plan """ # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, - task.local_block_ids, - request_info.block_ids + task.layer_name, task.local_block_ids, request_info.block_ids ) request_info.transfer_offset = offsets - + # Get session index - layer_names = list( - self.worker.layer_name_to_local_kv_cache_metadata.keys() - ) + layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) sess_idx = layer_names.index(task.layer_name) - + local_off, remote_off, sizes = request_info.transfer_offset - + return LayerTransferPlan( request_id=task.request_id, layer_name=task.layer_name, @@ -450,16 +453,12 @@ class MoRIIOWriter: transfer_local_offsets=local_off, transfer_remote_offsets=remote_off, transfer_sizes=sizes, - use_batch=True + use_batch=True, ) - - def _do_layer_write( - self, - plan: LayerTransferPlan, - sessions: list - ) -> None: + + def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: """Perform the actual layer write. - + Args: plan: The transfer plan sessions: List of transfer sessions @@ -469,7 +468,7 @@ class MoRIIOWriter: plan.transfer_sizes, plan.transfer_local_offsets, plan.transfer_remote_offsets, - sessions[plan.sess_idx] + sessions[plan.sess_idx], ) else: for i in range(len(plan.transfer_local_offsets)): @@ -477,60 +476,59 @@ class MoRIIOWriter: plan.transfer_sizes[i], plan.transfer_local_offsets[i], plan.transfer_remote_offsets[i], - plan.sess_idx + plan.sess_idx, ) - + def _finalize_if_complete( - self, - task: WriteTask, - request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo ) -> None: """Finalize transfer if all layers are complete. - + Args: task: The write task request_info: Remote allocation information """ request_info.writes_done += 1 - + if request_info.writes_done >= self.worker.num_layers: # Wait for transfer to complete self.worker.moriio_wrapper.waiting_for_transfer_complete() - - + remote_port = task.remote_notify_port + get_port_offset( - request_info.decode_dp_rank, - self.worker.tp_rank + request_info.decode_dp_rank, self.worker.tp_rank ) # TODO: # Consider using RDMA immediate data in decode side to eliminate the need for this notification. # Consider including the first gen token from prefill in the notification - + # Send completion notification self.worker.moriio_wrapper.send_notify( - task.request_id, - task.remote_ip, - remote_port + task.request_id, task.remote_ip, remote_port ) - del self.worker.moriio_wrapper.done_remote_allocate_req_dict[task.request_id] + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + task.request_id + ] logger.debug( "Completed transfer for request %s, notified port %d", - task.request_id, remote_port + task.request_id, + remote_port, ) + + class MoRIIOWrapper: """Wrapper for MoRIIO engine operations. - + Handles both producer and consumer roles for KV cache transfers. 
- + Args: moriio_engine: MoRIIO engine instance tp_rank: Tensor parallel rank dp_rank: Data parallel rank """ - - def __init__(self, moriio_engine=None,tp_rank=0,dp_rank=0): - self.tp_rank=tp_rank - self.dp_rank=dp_rank + + def __init__(self, moriio_engine=None, tp_rank=0, dp_rank=0): + self.tp_rank = tp_rank + self.dp_rank = dp_rank self.moriio_engine = moriio_engine self.remote_memory_metadata = None self.local_memory_registered = False @@ -548,10 +546,11 @@ class MoRIIOWrapper: self.sessions = [] self.kv_caches = None self.paths = {} - def set_moriio_engine(self, moriio_engine): - assert moriio_engine is not None, "You Cannot pass None engine to MoRIIOWrapper!" + assert moriio_engine is not None, ( + "You Cannot pass None engine to MoRIIOWrapper!" + ) self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): @@ -563,15 +562,15 @@ class MoRIIOWrapper: return engine_metadata_packed def register_remote_engine(self, remote_packed_engine_metadata): - consumer_engine_metadata = EngineDesc.unpack( - remote_packed_engine_metadata) + consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) self.moriio_engine.register_remote_engine(consumer_engine_metadata) return consumer_engine_metadata.key def register_local_tensor(self, tensor: torch.Tensor): try: self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor) + tensor + ) local_memory_metadata_packed = self.local_memory_metadata.pack() except Exception as e: raise MoRIIOError(f"Failed to register local memory: {e}") from e @@ -582,45 +581,47 @@ class MoRIIOWrapper: return MemoryDesc.unpack(packed_memory_metadata) def build_session(self, local_memory_metadata, remote_memory_metadata): - return self.moriio_engine.create_session(local_memory_metadata, - remote_memory_metadata) + return self.moriio_engine.create_session( + local_memory_metadata, remote_memory_metadata + ) - def read_remote_data(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - session=None): + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): assert self.local_memory_registered, "You have not register local memory data!" transfer_status = session.batch_read( - local_offset, remote_offset, transfer_size_byte, - self.moriio_engine.allocate_transfer_uid()) + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) return transfer_status - def write_remote_data(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - session=None): + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): assert self.local_memory_registered, "You have not register local memory data!" write_uid = self.moriio_engine.allocate_transfer_uid() - transfer_status = session.batch_write(local_offset, remote_offset, - transfer_size_byte, write_uid) + transfer_status = session.batch_write( + local_offset, remote_offset, transfer_size_byte, write_uid + ) with self.lock: self.transfer_status.append(transfer_status) - def write_remote_data_single(self, - transfer_size_byte, - local_offset=0, - remote_offset=0, - sess_idx=0): + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): assert self.local_memory_registered, "You have not register local memory data!" 
transfer_status = self.sessions[sess_idx].write( - local_offset, remote_offset, transfer_size_byte, - self.moriio_engine.allocate_transfer_uid()) + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) with self.lock: self.transfer_status.append(transfer_status) @@ -640,14 +641,12 @@ class MoRIIOWrapper: logger.error( f"Transfer failed: {status.Message()}, Code: {status.Code()}" ) - raise TransferError(f"MoRIIO transfer failed!") + raise TransferError("MoRIIO transfer failed!") except Exception as e: logger.error(f"Transfer {status} failed: {e}") raise def async_wait_reqid(self): - - assert self.notify_port is not None, "Notify port cannot be None" if self.notify_thread is not None: @@ -668,9 +667,9 @@ class MoRIIOWrapper: raise HandshakeError(f"Error processing message: {e}") from e continue - self.notify_thread = threading.Thread(target=_async_wait, - daemon=True, - name="moriio-notify-listener") + self.notify_thread = threading.Thread( + target=_async_wait, daemon=True, name="moriio-notify-listener" + ) self.notify_thread.start() def _handle_message(self, msg: bytes): @@ -689,8 +688,7 @@ class MoRIIOWrapper: handled = True return - except (msgpack.exceptions.ExtraData, - msgpack.exceptions.UnpackException): + except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): pass try: @@ -706,11 +704,15 @@ class MoRIIOWrapper: def _handle_structured_message(self, data: dict): req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) - decode_dp_rank=data.get("decode_rank",0) - assert len(block_notify_list) > 0, "block_notify_list cannot be empty in remote allocate message" + decode_dp_rank = data.get("decode_rank", 0) + assert len(block_notify_list) > 0, ( + "block_notify_list cannot be empty in remote allocate message" + ) with self.lock: - self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(block_ids=block_notify_list,decode_dp_rank=decode_dp_rank) + self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( + block_ids=block_notify_list, decode_dp_rank=decode_dp_rank + ) def _handle_completion_message(self, msg: str): with self.lock: @@ -728,10 +730,9 @@ class MoRIIOWrapper: if path not in self.paths: ctx = zmq.Context() - sock = make_zmq_socket(ctx=ctx, - path=path, - socket_type=zmq.DEALER, - bind=False) + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) self.paths[path] = sock req_list = req_ids if isinstance(req_ids, list) else [req_ids] @@ -740,8 +741,7 @@ class MoRIIOWrapper: try: for req_id in req_list: if not isinstance(req_id, str): - logger.warning( - f"Invalid req_id type: {type(req_id)}, expected str") + logger.warning(f"Invalid req_id type: {type(req_id)}, expected str") continue sock.send(req_id.encode("utf-8")) except Exception as e: @@ -764,12 +764,12 @@ class MoRIIOWrapper: return done_write_cache - class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True): + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] @@ -781,6 +781,7 @@ class MoRIIOAgentMetadata( @dataclass class ReqMeta: """Metadata for a single request.""" + local_block_ids: list[int] remote_block_ids: list[int] remote_host: str @@ -793,7 +794,6 @@ class ReqMeta: class MoRIIOConnectorMetadata(KVConnectorMetadata): - def __init__(self): self.reqs_to_recv: dict[ReqId, 
ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} @@ -817,17 +817,16 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): kv_transfer_params: dict[str, Any], write_mode=False, ): - _req = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], - remote_handshake_port=kv_transfer_params['remote_handshake_port'], - remote_notify_port=kv_transfer_params.get('remote_notify_port'), + remote_handshake_port=kv_transfer_params["remote_handshake_port"], + remote_notify_port=kv_transfer_params.get("remote_notify_port"), tp_size=kv_transfer_params.get("tp_size", 1), - remote_dp_size=kv_transfer_params.get("remote_dp_size", 1) + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), ) if write_mode: self.reqs_to_save[request_id] = _req @@ -836,43 +835,55 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): class MoRIIOConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None,): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None - self.engine_id = str( - get_ip()) + ":" + str(vllm_config.kv_transfer_config. - kv_connector_extra_config['handshake_port']) + self.engine_id = ( + str(get_ip()) + + ":" + + str( + vllm_config.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] + ) + ) self.mode = get_moriio_mode() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[MoRIIOConnectorScheduler] = \ + self.connector_scheduler: MoRIIOConnectorScheduler | None = ( MoRIIOConnectorScheduler(vllm_config, self.engine_id) - self.connector_worker: Optional[MoRIIOConnectorWorker] = None + ) + self.connector_worker: MoRIIOConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = MoRIIOConnectorWorker( - vllm_config, self.engine_id) + self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}" - ) + ) ############################################################ # Scheduler Side Methods ############################################################ def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens, self.connector_worker) + request, blocks, num_external_tokens, self.connector_worker + ) def build_connector_meta( self, @@ -885,7 +896,7 @@ class MoRIIOConnector(KVConnectorBase_V1): self, request: "Request", block_ids: list[int], - ) -> tuple[bool, 
Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -896,15 +907,12 @@ class MoRIIOConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if self.mode == MoRIIOMode.WRITE: if get_role() == ROLE.CONSUMER: self.connector_worker.moriio_wrapper.async_wait_reqid() @@ -915,27 +923,32 @@ class MoRIIOConnector(KVConnectorBase_V1): def wait_for_layer_load(self, layer_name: str) -> None: pass - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: # Only producer/prefill saves KV Cache if get_role() == ROLE.CONSUMER: - return - self.connector_worker.save_kv_layer(self._connector_metadata, - layer_name, kv_layer, - attn_metadata, **kwargs) + return + self.connector_worker.save_kv_layer( + self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs + ) return None def wait_for_save(self): pass + def has_connector_metadata(self) -> bool: """Check whether the connector metadata is currently set. Returns: bool: True if connector metadata exists, False otherwise. """ - try : + try: return self._connector_metadata is not None except AttributeError: return False @@ -949,17 +962,18 @@ class MoRIIOConnectorScheduler: self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - self.mode=get_moriio_mode() + self.mode = get_moriio_mode() - - - self.handeshake_port=self.vllm_config.kv_transfer_config.kv_connector_extra_config['handshake_port'] - logger.info( - f"==========> Initializing MoRIIO Scheduler {engine_id = }" + self.handeshake_port = ( + self.vllm_config.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] ) + logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }") - self.side_notify_port = self.vllm_config.kv_transfer_config.kv_connector_extra_config[ - 'notify_port'] + self.side_notify_port = ( + self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] + ) self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" @@ -968,12 +982,11 @@ class MoRIIOConnectorScheduler: # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} - + # For chunked prefill, we perform layer-wise access within the final chunk. # TODO: Perform access at the end of each chunk. 
self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {} - if self.is_producer: set_role(ROLE.PRODUCER) else: @@ -1006,7 +1019,6 @@ class MoRIIOConnectorScheduler: if self.is_producer: return 0, False - if self.mode == MoRIIOMode.WRITE: # MoriiO in write mode, no remote prefill @@ -1014,50 +1026,46 @@ class MoRIIOConnectorScheduler: return len(request.prompt_token_ids) - 1 - num_computed_tokens, False - def send_notify_block(self, - req_id: str, - block_notify_list: list[int] = None, - host=None, - port=None): - + def send_notify_block( + self, req_id: str, block_notify_list: list[int] = None, host=None, port=None + ): path = make_zmq_path("tcp", host, port) if path not in self.paths: ctx = zmq.Context() - sock = make_zmq_socket(ctx=ctx, - path=path, - socket_type=zmq.DEALER, - bind=False) + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) self.paths[path] = sock data = { "req_id": req_id, "block_notify_list": block_notify_list or [], "decode_rank": self.dp_rank, - "type": "remote_blocks" + "type": "remote_blocks", } # logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }") serialized_data = msgpack.dumps(data) self.paths[path].send(serialized_data) def update_state_after_alloc( - self, - request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int, - connector_worker: Optional["MoRIIOConnectorWorker"] = None): - + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + connector_worker: Optional["MoRIIOConnectorWorker"] = None, + ): params = request.kv_transfer_params if params.get("do_remote_decode"): local_block_ids = blocks.get_block_ids()[0] - self._reqs_need_save[request.request_id] = (request, - local_block_ids) + self._reqs_need_save[request.request_id] = (request, local_block_ids) if params is not None and params.get("do_remote_prefill"): if self.mode == MoRIIOMode.READ: if remote_block_ids := params.get("remote_block_ids"): - if all(p in params - for p in ("remote_engine_id", "remote_host", - "remote_port")): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): # If remote_blocks and num_external_tokens = 0, we # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. @@ -1068,29 +1076,33 @@ class MoRIIOConnectorScheduler: if len(local_block_ids) == len(remote_block_ids): pass else: - local_block_ids = remote_block_ids[ - -len(local_block_ids):] + local_block_ids = remote_block_ids[-len(local_block_ids) :] self._reqs_need_recv[request.request_id] = ( - request, local_block_ids) + request, + local_block_ids, + ) else: logger.warning( "Got invalid KVTransferParams: %s. 
This " - "request will not utilize KVTransfer", params) - + "request will not utilize KVTransfer", + params, + ) + else: - remote_dp_rank = request.kv_transfer_params.get('remote_dp_rank', 0) + remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) for tp_index in range(self.tp_size): - target_port = request.kv_transfer_params[ - 'remote_notify_port'] + get_port_offset(remote_dp_rank, tp_index) - + "remote_notify_port" + ] + get_port_offset(remote_dp_rank, tp_index) - self.send_notify_block(req_id=request.request_id, - block_notify_list=blocks.get_block_ids()[0], - host=params.get("remote_host"), - port=target_port) + self.send_notify_block( + req_id=request.request_id, + block_notify_list=blocks.get_block_ids()[0], + host=params.get("remote_host"), + port=target_port, + ) # Only trigger 1 KV transfer per request. @@ -1105,32 +1117,44 @@ class MoRIIOConnectorScheduler: if self.mode == MoRIIOMode.WRITE: # when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs - if get_role()== ROLE.CONSUMER: + if get_role() == ROLE.CONSUMER: for new_req in scheduler_output.scheduled_new_reqs: red_id = new_req.req_id local_block_ids = list(new_req.block_ids) kv_transfer_params = new_req.sampling_params.extra_args[ - 'kv_transfer_params'] + "kv_transfer_params" + ] meta.add_new_req( red_id, local_block_ids, kv_transfer_params, ) - if get_role()== ROLE.PRODUCER: - # This is the logic for checking against chunked prefill. + if get_role() == ROLE.PRODUCER: + # This is the logic for checking against chunked prefill. # When the last chunk is identified, it places the request metadata into the saving queue. - - for i,req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids): - new_block_ids = scheduler_output.scheduled_cached_reqs.new_block_ids[i] - - if new_block_ids is not None: + + for i, req_id in enumerate( + scheduler_output.scheduled_cached_reqs.req_ids + ): + new_block_ids = ( + scheduler_output.scheduled_cached_reqs.new_block_ids[i] + ) + + if new_block_ids is not None: block_ids = new_block_ids[0] - + req, existing_blocks = self._reqs_need_pending_save[req_id] - updated_blocks = list(existing_blocks) + ([block_ids] if isinstance(block_ids, int) else block_ids) + updated_blocks = list(existing_blocks) + ( + [block_ids] if isinstance(block_ids, int) else block_ids + ) self._reqs_need_pending_save[req_id] = (req, updated_blocks) - if len(self._reqs_need_pending_save[req_id][1]*self.block_size)>=req.num_prompt_tokens: - + if ( + len( + self._reqs_need_pending_save[req_id][1] + * self.block_size + ) + >= req.num_prompt_tokens + ): meta.add_new_req( request_id=req_id, local_block_ids=self._reqs_need_pending_save[req_id][1], @@ -1138,8 +1162,7 @@ class MoRIIOConnectorScheduler: write_mode=True, ) del self._reqs_need_pending_save[req_id] - - + # Loop through scheduled reqs and convert to ReqMeta. 
for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None @@ -1151,7 +1174,7 @@ class MoRIIOConnectorScheduler: for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - if req.num_prompt_tokens>len(block_ids): + if req.num_prompt_tokens > len(block_ids): # not last chunk prefill self._reqs_need_pending_save[req_id] = (req, block_ids) continue @@ -1175,7 +1198,7 @@ class MoRIIOConnectorScheduler: self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. @@ -1184,7 +1207,10 @@ class MoRIIOConnectorScheduler: params = request.kv_transfer_params logger.debug( "MoriioConnector request_finished, request_status=%s, " - "kv_transfer_params=%s", request.status, params) + "kv_transfer_params=%s", + request.status, + params, + ) if not params: return False, None @@ -1199,8 +1225,10 @@ class MoRIIOConnectorScheduler: params["do_remote_prefill"] = False return False, None - if (not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + if ( + not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED + ): return False, None # computed_block_ids = block_ids if all_full else block_ids[:-1] @@ -1210,9 +1238,10 @@ class MoRIIOConnectorScheduler: if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[request.request_id] = time.perf_counter( - ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT - + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) + # If we execute in P-D serial mode, no notification port is needed. return delay_free_blocks, dict( do_remote_prefill=True, @@ -1221,7 +1250,8 @@ class MoRIIOConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.handeshake_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) class MoRIIOConnectorWorker: @@ -1233,9 +1263,9 @@ class MoRIIOConnectorWorker: "MoRIIO is not available. Please ensure the 'mori' package " "is installed and properly configured." 
) - + self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config) - self.mode=get_moriio_mode() + self.mode = get_moriio_mode() logger.info("Initializing MoRIIO worker %s", engine_id) # for debug @@ -1254,59 +1284,75 @@ class MoRIIOConnectorWorker: self._rank = get_world_group().rank self._local_rank = get_world_group().local_rank self.tp_rank = self.moriio_config.tp_rank - self.dp_rank= self.moriio_config.dp_rank - - logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" - f",{self.is_producer= }") - - + self.dp_rank = self.moriio_config.dp_rank + + logger.info( + f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" + f",{self.is_producer= }" + ) + self.local_ip = self.moriio_config.local_ip - self.local_kv_port=self.moriio_config.local_kv_port + self.local_kv_port = self.moriio_config.local_kv_port self.proxy_ip = self.moriio_config.proxy_ip self.proxy_port = self.moriio_config.proxy_port self.local_ping_port = self.moriio_config.local_ping_port - self.proxy_ping_port =self.moriio_config.proxy_ping_port + self.proxy_ping_port = self.moriio_config.proxy_ping_port self.http_port = self.moriio_config.http_port self.handshake_port = self.moriio_config.handshake_port self.notify_port = self.moriio_config.notify_port - + self.zmq_context = zmq.Context() - self.metadata_address = f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" - self.request_address = f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + self.metadata_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" + ) + self.request_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + ) self.moriio_engine = None self._handle_request_thread = None self._ping_thread = None self._writer = MoRIIOWriter(self) - - engine_suffix = (f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" - f":tp {self.tp_rank}:dp {self.dp_rank}") + + engine_suffix = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" + f":tp {self.tp_rank}:dp {self.dp_rank}" + ) if not self.is_producer: self.poller = zmq.Poller() self.metadata_socket = self.zmq_context.socket(zmq.ROUTER) self.metadata_socket.bind(f"tcp://{self.metadata_address}") self.poller.register(self.metadata_socket, zmq.POLLIN) - self.moriio_engine = IOEngine( "consumer:" + engine_suffix, - IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) - + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), + ) + self._handle_request_thread = threading.Thread( - target=self.handle_proxy_request, daemon=True) + target=self.handle_proxy_request, daemon=True + ) self._handle_request_thread.start() else: - self.moriio_engine = IOEngine( "producer:" + engine_suffix, - IOEngineConfig(self.moriio_config.local_ip, self.moriio_config.local_kv_port)) - - logger.info("build IOEngine %s:%s", self.moriio_config.local_ip, self.moriio_config.local_kv_port) + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), + ) + + logger.info( + "build IOEngine %s:%s", + self.moriio_config.local_ip, + self.moriio_config.local_kv_port, + ) if self._rank == 0 and self.moriio_config.proxy_ip: - self._ping_thread = threading.Thread(target=self._ping, - args=(self.zmq_context, ), - daemon=True) + self._ping_thread = threading.Thread( + target=self._ping, args=(self.zmq_context,), daemon=True + ) self._ping_thread.start() logger.info( @@ -1316,24 +1362,23 @@ class MoRIIOConnectorWorker: f"{self.local_ip = 
},{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.proxy_port = },{self.local_ping_port = },{self.proxy_ping_port = }" ) # Agent. - self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank,dp_rank=self.dp_rank) - + self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) + self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.notify_port = self.moriio_config.notify_port - - + self.local_kv_cache_metadata = [] self.local_kv_cache_size = [] - self.layer_name_to_local_kv_cache_metadata: dict[str, - List[Any]] = dict() + self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() self.remote_kv_cache_metadata = [] self.remote_kv_cache_size = [] - self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[ - str, List[Any]]] = dict() + self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( + dict() + ) self.slot_size_bytes = 0 self.load_ready_flag = False @@ -1349,8 +1394,8 @@ class MoRIIOConnectorWorker: self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self.side_channel_port: int = ( - self.moriio_config.handshake_port + - get_port_offset(self.dp_rank,self.tp_rank) + self.moriio_config.handshake_port + + get_port_offset(self.dp_rank, self.tp_rank) ) logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}, han") @@ -1371,25 +1416,24 @@ class MoRIIOConnectorWorker: self.num_regions = 0 self.num_layers = 0 - - # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. self.dst_num_blocks: dict[EngineId, int] = {} # In progress transfers. - self._recving_transfers:defaultdict[ReqId, list]=defaultdict(list) - self._recving_transfers_callback_addr: dict[ReqId, tuple[str,str]]= {} - + self._recving_transfers: defaultdict[ReqId, list] = defaultdict(list) + self._recving_transfers_callback_addr: dict[ReqId, tuple[str, str]] = {} + # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Background thread for handling new handshake requests. - self._moriio_handshake_listener_t: Optional[threading.Thread] = None + self._moriio_handshake_listener_t: threading.Thread | None = None # Background thread for initializing new MoRIIO handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # MoRIIO is not guaranteed to be thread-safe, limit 1 worker. max_workers=1, - thread_name_prefix="vllm-moriio-handshake-initiator") + thread_name_prefix="vllm-moriio-handshake-initiator", + ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} # Protects _handshake_futures and _remote_agents. 
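The `side_channel_port` above and the KV, ping, and notify ports built in `MoRIIOConfig.from_vllm_config` are all derived the same way: a base port plus a per-(dp_rank, tp_rank) offset. A minimal sketch of that mapping, using the `RANK_PER_NODE = 8` default from `MoRIIOConstants` and a hypothetical base port (not part of this patch):

# Sketch only: how get_port_offset spreads ports across the ranks of one node.
RANK_PER_NODE = 8

def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    return (dp_rank * tp_size + tp_rank) % RANK_PER_NODE

base_handshake_port = 17000      # hypothetical base port
for dp_rank in range(2):         # hypothetical dp_size = 2
    for tp_rank in range(2):     # hypothetical tp_size = 2
        # from_vllm_config and the side-channel setup call this with the
        # default tp_size, so the offset is (dp_rank + tp_rank) % RANK_PER_NODE.
        offset = get_port_offset(dp_rank, tp_rank)
        print(f"dp={dp_rank} tp={tp_rank} -> {base_handshake_port + offset}")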
@@ -1402,17 +1446,19 @@ class MoRIIOConnectorWorker: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) # List of block window sizes for each layer for local attention - self.block_window_per_layer: list[Optional[int]] = [] + self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) self._write_session_lock = threading.Lock() self.debug_cache = [] - backend = get_attn_backend(self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - use_mla=self.use_mla) + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) self.backend_name = backend.get_name() attn_backend = AttentionBackendEnum[self.backend_name] self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER @@ -1422,7 +1468,6 @@ class MoRIIOConnectorWorker: logger.debug("Detected attention backend %s", self.backend_name) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} - ####write worker### # self._write_task_q: Queue[WriteTask] = Queue() @@ -1431,21 +1476,19 @@ class MoRIIOConnectorWorker: # self._deferred_tasks: list[WriteTask] = [] # ####write worker### - - def schedule_write_blocks( self, request_id: str, dst_engine_id: str, local_block_ids: list[int], - remote_block_ids: Optional[list[int]], + remote_block_ids: list[int] | None, layer_name: str, kv_layer: torch.Tensor, remote_notify_port: int, - remote_ip: str + remote_ip: str, ) -> None: """Schedule a block write operation. - + Args: request_id: Unique identifier for the request dst_engine_id: Destination engine ID @@ -1461,50 +1504,53 @@ class MoRIIOConnectorWorker: event = torch.cuda.Event() event.record(stream) - task = WriteTask(request_id=request_id, - dst_engine_id=dst_engine_id, - local_block_ids=local_block_ids, - remote_block_ids_hint=remote_block_ids, - layer_name=layer_name, - event=event, - remote_notify_port=remote_notify_port, - remote_ip=remote_ip) + task = WriteTask( + request_id=request_id, + dst_engine_id=dst_engine_id, + local_block_ids=local_block_ids, + remote_block_ids_hint=remote_block_ids, + layer_name=layer_name, + event=event, + remote_notify_port=remote_notify_port, + remote_ip=remote_ip, + ) self._writer.schedule_write(task) - - def _get_built_session(self, remote_engine_id): if remote_engine_id not in self.built_write_session: cur_remote_engine_sessions = [] - for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items( - ): - - unpcaked_local_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata( - local_meta[0]) - unpcaked_remote_memory_meta = self.moriio_wrapper.get_unpack_memory_metadata( - self.layer_name_to_remote_kv_cache_metadata[ - remote_engine_id][ln][0]) + for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): + unpcaked_local_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) + ) + unpcaked_remote_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata( + self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ + ln + ][0] + ) + ) cur_remote_engine_sessions.append( self.moriio_wrapper.build_session( - unpcaked_local_memory_meta, - unpcaked_remote_memory_meta)) - self.built_write_session[ - remote_engine_id] = cur_remote_engine_sessions + 
unpcaked_local_memory_meta, unpcaked_remote_memory_meta + ) + ) + self.built_write_session[remote_engine_id] = cur_remote_engine_sessions return self.built_write_session[remote_engine_id] def _ping(self, zmq_context): PING_INTERVAL = 5 - MAX_RETRIES =100000 - + MAX_RETRIES = 100000 + http_request_address = f"http://{self.request_address}/v1/completions" role = "P" if self.is_producer else "D" - + retry_count = 0 index = 1 - + with zmq_context.socket(zmq.DEALER) as sock: sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}") - + while True: try: data = { @@ -1514,47 +1560,50 @@ class MoRIIOConnectorWorker: "request_address": http_request_address, "handshake_port": self.handshake_port, "notify_port": self.notify_port, - "dp_size":self.moriio_config.dp_size, - "tp_size":self.moriio_config.tp_size, - "transfer_mode":self.mode.name, + "dp_size": self.moriio_config.dp_size, + "tp_size": self.moriio_config.tp_size, + "transfer_mode": self.mode.name, } sock.send(msgpack.dumps(data)) # logger.debug(f"Successfully sent ping message #{index}") - retry_count = 0 - + retry_count = 0 + except ConnectionRefusedError: logger.info( f"Connection refused: {self.local_ip}:{self.local_ping_port} -> " f"{self.proxy_ip}:{self.proxy_ping_port}" ) retry_count += 1 - + except OSError as e: logger.info(f"OS error when sending ping: {e}") retry_count += 1 - + except Exception as e: logger.info(f"Unexpected error when sending ping: {e}") retry_count += 1 - + finally: if retry_count >= MAX_RETRIES: - logger.error(f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop.") + logger.error( + f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop." + ) break - + time.sleep(PING_INTERVAL) index += 1 def handle_proxy_request(self): if self.is_producer: raise NotImplementedError( - "prefill instance doesn't need to send kv cache in pull mode") + "prefill instance doesn't need to send kv cache in pull mode" + ) while True: socks = dict(self.poller.poll()) logger.debug(f"handle_proxy_request: {socks = }") - - #TODO: inkcherry , check here? + + # TODO: inkcherry , check here? if self.metadata_socket not in socks: continue else: @@ -1568,44 +1617,55 @@ class MoRIIOConnectorWorker: @staticmethod def _moriio_handshake_listener( - metadata: MoRIIOAgentMetadata, ready_event: threading.Event, - base_port: int, tp_rank: int,dp_rank:int, - layer_name_to_local_kv_cache_metadata: dict): + metadata: MoRIIOAgentMetadata, + ready_event: threading.Event, + base_port: int, + tp_rank: int, + dp_rank: int, + layer_name_to_local_kv_cache_metadata: dict, + ): """Background thread for getting new MoRIIO handshakes.""" encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug("Size of encoded MoRIIOAgentMetadata: %s bytes", - str(size_in_bytes)) + logger.debug( + "Size of encoded MoRIIOAgentMetadata: %s bytes", str(size_in_bytes) + ) # Listen for new requests for metadata. 
host = "*" - logger.info(f"======> mori handeshake starting listening on baseport: {base_port}") + logger.info( + f"======> mori handeshake starting listening on baseport: {base_port}" + ) - path = make_zmq_path("tcp", host, base_port ) + path = make_zmq_path("tcp", host, base_port) logger.info(f"======> mori handeshake sstarting listening on path: {path}") with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() while True: identity, msg = sock.recv_multipart() - if msg != MoRIIOConstants.GET_META_MSG and msg != MoRIIOConstants.POP_DONE_RECV: - logger.error( - "Connection listener got unexpected message %s", msg) + if ( + msg != MoRIIOConstants.GET_META_MSG + and msg != MoRIIOConstants.POP_DONE_RECV + ): + logger.error("Connection listener got unexpected message %s", msg) raise HandshakeError("handshake failed, unexpected msg type") elif msg == MoRIIOConstants.GET_META_MSG: sock.send_multipart( - (identity, b"", - encoded_data)) # send local mori io engine meta data + (identity, b"", encoded_data) + ) # send local mori io engine meta data logger.info("MoRIIO handshake listener sent metadata to %s") # now we send tensor meta data for each block buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) sock.send_multipart((identity, b"", buf)) elif msg == MoRIIOConstants.POP_DONE_RECV: _, req_id = sock.recv_multipart() - logger.info("MoRIIO handshake listener received done recv for req %s", - req_id.decode()) + logger.info( + "MoRIIO handshake listener received done recv for req %s", + req_id.decode(), + ) else: pass @@ -1615,7 +1675,7 @@ class MoRIIOConnectorWorker: port: int, remote_tp_size: int, expected_engine_id: str, - remote_dp_rank:int=0, + remote_dp_rank: int = 0, ) -> dict[int, str]: """Do a MoRIIO handshake with a remote instance.""" @@ -1625,10 +1685,12 @@ class MoRIIOConnectorWorker: # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - - port_offset = get_port_offset(remote_dp_rank,self.tp_rank) + port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) - logger.info("handeshake Querying metadata on path: %s at remote rank %s", path,) + logger.info( + "handeshake Querying metadata on path: %s at remote rank %s", + path, + ) # Send query for the request. 
with zmq_ctx(zmq.DEALER, path) as sock: @@ -1642,17 +1704,22 @@ class MoRIIOConnectorWorker: decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - logger.info("MoRIIO handshake: get metadata took: %s", - got_metadata_time - start_time) + logger.info( + "MoRIIO handshake: get metadata took: %s", + got_metadata_time - start_time, + ) self.moriio_wrapper.remote_engine_ip = host remote_agent_name = self.moriio_wrapper.register_remote_engine( - metadata.agent_metadata) - remote_agent_name=EngineDesc.unpack(metadata.agent_metadata).key - - logger.info(f"MoRIIO handshake: registered remote agent " - f"{remote_agent_name=} for engine ID " - f"{expected_engine_id=},f{path= }") + metadata.agent_metadata + ) + remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key + + logger.info( + f"MoRIIO handshake: registered remote agent " + f"{remote_agent_name=} for engine ID " + f"{expected_engine_id=},f{path= }" + ) if len(self.local_kv_cache_metadata) > 0: logger.warning( f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" @@ -1668,18 +1735,21 @@ class MoRIIOConnectorWorker: if len(received_frame) != 2 or received_frame[0] != b"": assert 0, f"Unexpected frame! {received_frame = }" buf = received_frame[1] - self.layer_name_to_remote_kv_cache_metadata[ - expected_engine_id] = pickle.loads(buf) + self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( + pickle.loads(buf) + ) setup_agent_time = time.perf_counter() - logger.debug("MoRIIO handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + logger.debug( + "MoRIIO handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) return {remote_agent_name} - def _background_moriio_handshake(self, req_id: str, - remote_engine_id: EngineId, - meta: ReqMeta): + def _background_moriio_handshake( + self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta + ): # Do MoRIIO handshake in background and add to _ready_requests when done. fut = None if remote_engine_id is not None: @@ -1689,6 +1759,7 @@ class MoRIIOConnectorWorker: port = int(meta.remote_handshake_port) tp_size = int(meta.tp_size) remote_dp_size = int(meta.remote_dp_size) + # TODO: handle failure state of future in the # callback, we want to fail the request in this case. def request_ready(_f: Future[Any], entry=(req_id, meta)): @@ -1696,19 +1767,19 @@ class MoRIIOConnectorWorker: self._ready_requests.put(entry) self.load_ready_flag = True self.write_ready_flags[remote_engine_id] = True - + fut_list = [] - + # In dp(prefill)<->dp(decode) communication, we require an all-to-all handshake. 
for cur_dp_rank in range(remote_dp_size): dp_engine_id = f"{remote_engine_id}_dp{cur_dp_rank}" - + future = self._handshake_initiation_executor.submit( self._moriio_handshake, host, port, tp_size, dp_engine_id, cur_dp_rank ) fut_list.append(future) - + def done_callback(f: Future[dict[int, str]], eid=dp_engine_id): with self._handshake_lock: self._handshake_futures.pop(eid, None) @@ -1716,23 +1787,22 @@ class MoRIIOConnectorWorker: self._remote_agents[eid] = f.result() except Exception: logger.exception("Handshake with %s failed", eid) - + future.add_done_callback(done_callback) self._handshake_futures[dp_engine_id] = future - + # fut = fut_list def wait_all_dp(): for future in fut_list: - future.result() + future.result() return True all_done_future = self._handshake_initiation_executor.submit(wait_all_dp) all_done_future.add_done_callback(request_ready) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in moriio.""" - + # kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize) _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() @@ -1759,7 +1829,9 @@ class MoRIIOConnectorWorker: block_shape = first_kv_cache.shape[-block_rank:] block_size, n_kv_heads, head_dim = block_shape[-3:] # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim # 1 token 1 layer size , slot size + self.slot_size_bytes = ( + kv_elem_size * n_kv_heads * head_dim + ) # 1 token 1 layer size , slot size assert block_size == self.block_size # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc @@ -1784,35 +1856,32 @@ class MoRIIOConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). 
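For reference, the sizing arithmetic around `slot_size_bytes` and `region_len`, under the assumption (the `block_len` assignment sits outside this hunk) that `block_len = slot_size_bytes * block_size`; the tensor shape mirrors the non-MLA cache layout noted above, with made-up dimensions:

# Sketch only: bytes per token slot and per registered region for one layer.
# Assumes block_len = slot_size_bytes * block_size; all dimensions hypothetical.
import torch

num_blocks, block_size, n_kv_heads, head_dim = 8, 16, 8, 128
kv_cache = torch.empty(2, num_blocks, block_size, n_kv_heads, head_dim,
                       dtype=torch.float16)  # (2, num_blocks, block_size, heads, head_dim)

kv_elem_size = kv_cache.element_size()                    # 2 bytes for fp16
slot_size_bytes = kv_elem_size * n_kv_heads * head_dim    # one token, one layer
block_len = slot_size_bytes * block_size                  # assumed per-block byte length
region_len = num_blocks * block_len                       # bytes per K (or V) region
print(slot_size_bytes, block_len, region_len)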
-
-        for cache_or_caches in kv_caches.values():
-            cache_list = [
-                cache_or_caches
-            ] if use_mla or self._use_flashinfer else cache_or_caches
+        for cache_or_caches in kv_caches.values():
+            cache_list = (
+                [cache_or_caches]
+                if use_mla or self._use_flashinfer
+                else cache_or_caches
+            )
             # logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }")
             for cache in cache_list:
-
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len
-                caches_data.append(
-                    (base_addr, region_len, cache.device.index, ""))
+                caches_data.append((base_addr, region_len, cache.device.index, ""))
                 kv_caches_base_addr.append(base_addr)

         for layer_name, kv_cache in kv_caches.items():
-
             if layer_name not in self.layer_name_to_local_kv_cache_metadata:
                 self.layer_name_to_local_kv_cache_metadata[layer_name] = []

             # for cache in cache_list:
             #     moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache)
-            moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(
-                kv_cache)
+            moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache)
             self.layer_name_to_local_kv_cache_metadata[layer_name].append(
-                moriio_mem_metadata)
+                moriio_mem_metadata
+            )

-            self.local_kv_cache_size.append(cache.nelement() *
-                                            cache.element_size())
+            self.local_kv_cache_size.append(cache.nelement() * cache.element_size())

         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
@@ -1821,8 +1890,10 @@ class MoRIIOConnectorWorker:
         # Optimization for models with local attention (Llama 4)
         if self.vllm_config.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig
-            assert isinstance(self.vllm_config.model_config.hf_text_config,
-                              Llama4TextConfig)
+
+            assert isinstance(
+                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
+            )
             llama4_config = self.vllm_config.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
@@ -1833,8 +1904,10 @@ class MoRIIOConnectorWorker:
             is_local_attention = no_rope_layers[layer_idx] != 0
             block_window = chunk_block_size if is_local_attention else None
             self.block_window_per_layer.append(block_window)
-            logger.debug("Llama 4 block window per layer mapping: %s",
-                         self.block_window_per_layer)
+            logger.debug(
+                "Llama 4 block window per layer mapping: %s",
+                self.block_window_per_layer,
+            )
         assert len(self.block_window_per_layer) == self.num_layers

         metadata = MoRIIOAgentMetadata(
@@ -1843,14 +1916,22 @@ class MoRIIOConnectorWorker:
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+        )
         ready_event = threading.Event()
         self._moriio_handshake_listener_t = threading.Thread(
             target=self._moriio_handshake_listener,
-            args=(metadata, ready_event, self.side_channel_port, self.tp_rank,self.dp_rank,
-                  self.layer_name_to_local_kv_cache_metadata),
+            args=(
+                metadata,
+                ready_event,
+                self.side_channel_port,
+                self.tp_rank,
+                self.dp_rank,
+                self.layer_name_to_local_kv_cache_metadata,
+            ),
             daemon=True,
-            name="moriio_handshake_listener")
+            name="moriio_handshake_listener",
+        )
         self._moriio_handshake_listener_t.start()
         ready_event.wait()  # Wait for listener ZMQ socket to be ready.
         self.moriio_wrapper.async_wait_reqid()
@@ -1869,49 +1950,54 @@ class MoRIIOConnectorWorker:
                 if self.mode == MoRIIOMode.WRITE:
                     done_recving = set()
                 else:
-                    done_recving=self._pop_done_transfers()
+                    done_recving = self._pop_done_transfers()
         else:
             if self.mode == MoRIIOMode.WRITE:
                 self.moriio_wrapper.async_wait_reqid()
-                done_sending, done_recving = set(
-                ), self.moriio_wrapper.pop_finished_write_req_ids()
+                done_sending, done_recving = (
+                    set(),
+                    self.moriio_wrapper.pop_finished_write_req_ids(),
+                )
         return done_sending, done_recving
-

     def _pop_done_transfers(self) -> set[str]:
-
         done_req_ids: set[str] = set()
         for req_id, status_list in self._recving_transfers.items():
             if status_list[-1].Succeeded():
                 done_req_ids.add(req_id)
-
+
                 self.moriio_wrapper.send_notify(
-                    req_id,self._recving_transfers_callback_addr[req_id][0],
-                    self._recving_transfers_callback_addr[req_id][1])
+                    req_id,
+                    self._recving_transfers_callback_addr[req_id][0],
+                    self._recving_transfers_callback_addr[req_id][1],
+                )
                 del self._recving_transfers[req_id]
                 del self._recving_transfers_callback_addr[req_id]
-
+
         return done_req_ids

-
-    def save_kv_layer(self, metadata: MoRIIOConnectorMetadata, layer_name: str,
-                      kv_layer: torch.Tensor,
-                      attn_metadata: "AttentionMetadata", **kwargs):
-
+    def save_kv_layer(
+        self,
+        metadata: MoRIIOConnectorMetadata,
+        layer_name: str,
+        kv_layer: torch.Tensor,
+        attn_metadata: "AttentionMetadata",
+        **kwargs,
+    ):
         if not self.is_producer:
             return
         if self.mode == MoRIIOMode.READ:
             return
         remote_engine_id = None
-
-
         for req_id, meta in metadata.reqs_to_save.items():
             remote_engine_id = meta.remote_engine_id
             # we only need to check if dp0 in rank
-            remote_engine_id = str(meta.remote_host) + ":" + str(
-                meta.remote_handshake_port)
-
+            remote_engine_id = (
+                str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
+            )
+
             meta.remote_engine_id = remote_engine_id
             # TODO: mz get_remote_engine_id() for engine_id mapping.
@@ -1920,10 +2006,10 @@ class MoRIIOConnectorWorker:
             # Initiate handshake with remote engine to exchange metadata.
                 with self._handshake_lock:
                     if remote_engine_id not in self._remote_agents:
-                        logger.info(
-                            f"*****background moriio {remote_engine_id = }")
+                        logger.info(f"*****background moriio {remote_engine_id = }")
                         self._background_moriio_handshake(
-                            req_id, remote_engine_id, meta)
+                            req_id, remote_engine_id, meta
+                        )
                         continue

             self._write_blocks_for_req(req_id, meta, layer_name, kv_layer)
@@ -1931,12 +2017,17 @@ class MoRIIOConnectorWorker:
         while True:
             if remote_engine_id is None:
                 break
-            if self._ready_requests.empty() and remote_engine_id not in self.write_ready_flags:
+            if (
+                self._ready_requests.empty()
+                and remote_engine_id not in self.write_ready_flags
+            ):
                 continue
-            elif not self._ready_requests.empty() and (remote_engine_id
-                                                       in self.write_ready_flags):
-                self._write_blocks_for_req(*self._ready_requests.get_nowait(),
-                                           layer_name, kv_layer)
+            elif not self._ready_requests.empty() and (
+                remote_engine_id in self.write_ready_flags
+            ):
+                self._write_blocks_for_req(
+                    *self._ready_requests.get_nowait(), layer_name, kv_layer
+                )
                 break
             else:
                 break
@@ -1957,8 +2048,9 @@ class MoRIIOConnectorWorker:
         remote_engine_id = None
         for req_id, meta in metadata.reqs_to_recv.items():
-            remote_engine_id = str(meta.remote_host) + ":" + str(
-                meta.remote_handshake_port)
+            remote_engine_id = (
+                str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
+            )
             meta.remote_engine_id = remote_engine_id
             dp0_remote_engine_id = f"{remote_engine_id}_dp0"
             if dp0_remote_engine_id not in self._remote_agents:
@@ -1966,7 +2058,8 @@ class MoRIIOConnectorWorker:
                 with self._handshake_lock:
                     if remote_engine_id not in self._remote_agents:
                         self._background_moriio_handshake(
-                            req_id, remote_engine_id, meta)
+                            req_id, remote_engine_id, meta
+                        )
                         wait_handshage_readd_req = True
                         continue
@@ -1975,9 +2068,12 @@ class MoRIIOConnectorWorker:
             self._read_blocks_for_req(req_id, meta)

         # Start transfers for requests whose handshakes have now finished.
-        while True: #TODO
-            if self._ready_requests.empty(
-            ) and not self.load_ready_flag and wait_handshage_readd_req:
+        while True:  # TODO
+            if (
+                self._ready_requests.empty()
+                and not self.load_ready_flag
+                and wait_handshage_readd_req
+            ):
                 continue
             elif not self._ready_requests.empty() and self.load_ready_flag:
                 self._read_blocks_for_req(*self._ready_requests.get_nowait())
@@ -1986,12 +2082,13 @@ class MoRIIOConnectorWorker:
                 break
         self._reqs_to_send.update(metadata.reqs_to_send)
-

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
-            meta.remote_engine_id, req_id)
+            meta.remote_engine_id,
+            req_id,
+        )
         self._read_blocks(
             request_id=req_id,
             dst_engine_id=meta.remote_engine_id,
@@ -2001,18 +2098,19 @@ class MoRIIOConnectorWorker:
             remote_notify_port=meta.remote_notify_port,
         )

-    def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name,
-                              kv_layer):
+    def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer):
         # logger.debug(f"write block for req {req_id} to remote engine "
         #              f"{meta.remote_engine_id}")
-        self.schedule_write_blocks(request_id=req_id,
-                                   dst_engine_id=meta.remote_engine_id,
-                                   local_block_ids=meta.local_block_ids,
-                                   remote_block_ids=meta.remote_block_ids,
-                                   layer_name=layer_name,
-                                   kv_layer=kv_layer,
-                                   remote_notify_port=meta.remote_notify_port,
-                                   remote_ip=meta.remote_host)
+        self.schedule_write_blocks(
+            request_id=req_id,
+            dst_engine_id=meta.remote_engine_id,
+            local_block_ids=meta.local_block_ids,
+            remote_block_ids=meta.remote_block_ids,
+            layer_name=layer_name,
+            kv_layer=kv_layer,
+            remote_notify_port=meta.remote_notify_port,
+            remote_ip=meta.remote_host,
+        )

     def _is_last_layer(self, layer_name):
         if layer_name == list(self.kv_caches.keys())[-1]:
@@ -2025,12 +2123,12 @@ class MoRIIOConnectorWorker:
         return False

     def merge_contiguous_blocks(
-            self,
-            offsets_local: List[int],
-            offsets_remote: List[int],
-            sizes: List[int],
-            assume_sorted: bool = False
-    ) -> Tuple[List[int], List[int], List[int]]:
+        self,
+        offsets_local: list[int],
+        offsets_remote: list[int],
+        sizes: list[int],
+        assume_sorted: bool = False,
+    ) -> tuple[list[int], list[int], list[int]]:
         n = len(offsets_local)
         if n == 0:
             return [], [], []
@@ -2056,8 +2154,11 @@ class MoRIIOConnectorWorker:
         sizes_sorted = sizes_arr[sort_idx]

         if n == 1:
-            return [int(local_sorted[0])], [int(remote_sorted[0])
-                                            ], [int(sizes_sorted[0])]
+            return (
+                [int(local_sorted[0])],
+                [int(remote_sorted[0])],
+                [int(sizes_sorted[0])],
+            )

         diff_local = local_sorted[1:] - local_sorted[:-1]
         diff_remote = remote_sorted[1:] - remote_sorted[:-1]
@@ -2066,13 +2167,11 @@ class MoRIIOConnectorWorker:
         contiguous = (diff_local == prev_size) & (diff_remote == prev_size)

         if not contiguous.any():
-            return local_sorted.tolist(), remote_sorted.tolist(
-            ), sizes_sorted.tolist()
+            return local_sorted.tolist(), remote_sorted.tolist(), sizes_sorted.tolist()

         if contiguous.all():
             total_size = int(sizes_sorted.sum())
-            return [int(local_sorted[0])], [int(remote_sorted[0])
-                                            ], [total_size]
+            return [int(local_sorted[0])], [int(remote_sorted[0])], [total_size]

         break_positions = np.flatnonzero(~contiguous) + 1
         segment_starts = np.concatenate(([0], break_positions))
@@ -2089,8 +2188,9 @@ class MoRIIOConnectorWorker:
             merged_local[si] = int(local_sorted[s])
             merged_remote[si] = int(remote_sorted[s])
-            merged_sizes[si] = int(local_sorted[e - 1] + sizes_sorted[e - 1] -
-                                   local_sorted[s])
+            merged_sizes[si] = int(
+                local_sorted[e - 1] + sizes_sorted[e - 1] - local_sorted[s]
+            )

         return merged_local, merged_remote, merged_sizes
@@ -2099,18 +2199,18 @@ class MoRIIOConnectorWorker:
         layer_name: str,
         local_block_ids: list[int],
         remote_block_ids: list[int],
-    ) -> tuple[list[int], list[int], list[int]]:
+    ) -> tuple[list[int], list[int], list[int]]:
         """Compute transfer offsets for block data.
-
+
         Args:
             layer_name: Name of the layer to transfer
             local_block_ids: IDs of local blocks
            remote_block_ids: IDs of remote blocks
-
+
         Returns:
             Tuple of (local_offsets, remote_offsets, transfer_sizes)
         """
-        is_mla = (len(self.kv_cache_shape) == 3)
+        is_mla = len(self.kv_cache_shape) == 3
         stride = self.kv_caches[layer_name].stride()
         sz = self.kv_caches[layer_name].element_size()
         if is_mla:
@@ -2144,32 +2244,46 @@ class MoRIIOConnectorWorker:
                 w += 1

         merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
-            offset_local, offset_remote, sizes, assume_sorted=True)
+            offset_local, offset_remote, sizes, assume_sorted=True
+        )
         return merged_l, merged_r, merged_s

-    def _read_blocks(self, local_block_ids: list[int],
-                     remote_block_ids: list[int], dst_engine_id: str,
-                     request_id: str,
-                     remote_host: str,
-                     remote_notify_port: int)-> None:
+    def _read_blocks(
+        self,
+        local_block_ids: list[int],
+        remote_block_ids: list[int],
+        dst_engine_id: str,
+        request_id: str,
+        remote_host: str,
+        remote_notify_port: int,
+    ) -> None:
         if self.mode == MoRIIOMode.WRITE:
             return
         # we only test TP<->TP in read mode
         # assert self.dp_rank>0, "only test TP<->TP in read mode"
-        dst_engine_id+="_dp0"
+        dst_engine_id += "_dp0"
         sessions = self._get_built_session(dst_engine_id)
-
+
         first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
-        offs = self._compute_block_transfer_offsets(first_layer, local_block_ids, remote_block_ids)
+        offs = self._compute_block_transfer_offsets(
+            first_layer, local_block_ids, remote_block_ids
+        )
         a, b, c = offs[0], offs[1], offs[2]
         for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
-            sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(layer_name)
-            transfer_status=self.moriio_wrapper.read_remote_data(c, a, b, sessions[sess_idx])
-
+            sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
+                layer_name
+            )
+            transfer_status = self.moriio_wrapper.read_remote_data(
+                c, a, b, sessions[sess_idx]
+            )
+
             self._recving_transfers[request_id].append(transfer_status)
-        self._recving_transfers_callback_addr[request_id]=(remote_host,remote_notify_port + self.tp_rank)
-
+        self._recving_transfers_callback_addr[request_id] = (
+            remote_host,
+            remote_notify_port + self.tp_rank,
+        )
+

 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
@@ -2178,13 +2292,12 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
     if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):
         raise ValueError(f"Unexpected socket type: {socket_type}")

-    ctx: Optional[zmq.Context] = None
+    ctx: zmq.Context | None = None
     try:
         ctx = zmq.Context()  # type: ignore[attr-defined]
-        yield make_zmq_socket(ctx=ctx,
-                              path=addr,
-                              socket_type=socket_type,
-                              bind=socket_type == zmq.ROUTER)
+        yield make_zmq_socket(
+            ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
+        )
     finally:
         if ctx is not None:
-            ctx.destroy(linger=0)
\ No newline at end of file
+            ctx.destroy(linger=0)