diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 9cefa0cd835d9..931462f3c3dde 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -37,7 +37,7 @@ from vllm.distributed.parallel_state import ( ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -69,7 +69,8 @@ class MoRIIOConstants: PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 - + DEFAULT_HANDSHAKE_PORT = "6301" + DEFAULT_NOTIFY_PORT="61005" try: from mori.io import ( @@ -78,6 +79,8 @@ try: IOEngine, IOEngineConfig, MemoryDesc, + PollCqMode, + RdmaBackendConfig ) logger.info("MoRIIO is available") @@ -192,11 +195,12 @@ class TransferError(MoRIIOError): def get_moriio_mode() -> MoRIIOMode: read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower() - # logger.info(f"MoRIIO Connector Read Mode = {read_mode}") + logger.debug("MoRIIO Connector read_mode: %s", read_mode) if read_mode in ("true", "1", "yes", "on"): return MoRIIOMode.READ else: return MoRIIOMode.WRITE + def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: @@ -220,19 +224,21 @@ class MoRIIOConfig: @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": + # Port Configuration: - # local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it) + # local_ping_port -> Outgoing heartbeat to proxy # proxy_ping_port -> Remote proxy's heartbeat ingress port # http_port -> Instance's HTTP service endpoint - # local_kv_port -> KV service port for Mori engine - # notify_port -> For synchronizing stages between nodes + # local_kv_port -> service port for mori engine + # notify_port -> For synchronizing stages between prefill and decode + # handshake_port -> For initial handshake between mori engine + #TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports + kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() dp_rank = vllm_config.parallel_config.data_parallel_rank - base_kv_port = int(kv_transfer_config.kv_port) - base_ping_port = int(extra_config["local_ping_port"]) base_notify_port = int(extra_config["notify_port"]) dp_size = vllm_config.parallel_config.data_parallel_size tp_size = get_tensor_model_parallel_world_size() @@ -240,9 +246,9 @@ class MoRIIOConfig: return cls( local_ip=get_ip(), - local_kv_port=base_kv_port + port_offset, + local_kv_port=get_open_port(), proxy_ip=extra_config["proxy_ip"], - local_ping_port=base_ping_port + port_offset, + local_ping_port=get_open_port(), proxy_ping_port=int(extra_config["proxy_ping_port"]), http_port=int(extra_config["http_port"]), handshake_port=int(extra_config["handshake_port"]), @@ -545,7 +551,7 @@ class MoRIIOWrapper: self.notify_thread = None self.sock = None self.sessions: list[IOEngine.Session] = [] - self.paths = {} + self.paths: dict[str, zmq.Socket] = {} def set_moriio_engine(self, moriio_engine): assert moriio_engine is not None, ( @@ -554,7 +560,17 @@ class MoRIIOWrapper: self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): - self.moriio_engine.create_backend(backend_type) + qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1")) + post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1")) + num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1")) + poll_mode = PollCqMode.POLLING + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_worker_threads, + poll_mode, + ) + self.moriio_engine.create_backend(backend_type, rdma_cfg) def get_agent_metadata(self): engine_metadata = self.moriio_engine.get_engine_desc() @@ -700,6 +716,8 @@ class MoRIIOWrapper: raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): + + assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages" req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) decode_dp_rank = data.get("decode_rank", 0) @@ -882,16 +900,12 @@ class MoRIIOConnector(KVConnectorBase_V1): extra_config = kv_transfer_config.kv_connector_extra_config if "handshake_port" not in extra_config or not extra_config["handshake_port"]: - extra_config["handshake_port"] = "6301" + extra_config["handshake_port"] = MoRIIOConstants.DEFAULT_HANDSHAKE_PORT if "notify_port" not in extra_config or not extra_config["notify_port"]: - extra_config["notify_port"] = "61005" + extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT - if "local_ping_port" not in extra_config or not extra_config["local_ping_port"]: - extra_config["local_ping_port"] = "7583" - if not kv_transfer_config.kv_port: - kv_transfer_config.kv_port = "7305" def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -1180,7 +1194,7 @@ class MoRIIOConnectorScheduler: if new_block_ids is not None: block_ids = new_block_ids[0] - + #TODO : hybrid attn, etc req, existing_blocks = self._reqs_need_pending_save[req_id] updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) @@ -2261,6 +2275,7 @@ class MoRIIOConnectorWorker: sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) + #TODO : apply multi-session batch-read when moriio support it transfer_status = self.moriio_wrapper.read_remote_data( offs[2], offs[0], offs[1], sessions[sess_idx] )