Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 07:31:29 +00:00
parent 4f592ae696
commit b60ee86585

View File

@ -1,19 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import logging
import math
import os
import pickle
import queue
import threading
import time
import msgpack
import msgspec
import numpy as np
import torch
import zmq
import logging
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
@ -21,6 +15,12 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from weakref import ref as weakref_ref
import msgpack
import msgspec
import numpy as np
import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
@ -68,10 +68,10 @@ class MoRIIOConstants:
OVER = b"OVER"
COMPLETION_PREFIX = "cmpl"
PING_INTERVAL = 5
MAX_PING_RETRIES = 100000
try:
import mori
from mori.io import (
@ -202,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode:
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    """Flatten a (dp_rank, tp_rank) pair into a single port offset.

    Ranks are laid out DP-major: each data-parallel rank owns a
    contiguous range of ``tp_size`` consecutive offsets, and the
    tensor-parallel rank indexes within that range.
    """
    dp_base = dp_rank * tp_size
    return dp_base + tp_rank
@dataclass
@ -316,7 +316,6 @@ class MoRIIOWriter:
def _write_worker_loop(self) -> None:
"""Main loop for the write worker thread."""
logger.info("Write worker loop started")
while True:
# Process deferred tasks first
@ -408,7 +407,9 @@ class MoRIIOWriter:
task.event.synchronize()
# Update engine ID with DP rank
task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank)
task.dst_engine_id = self.worker.get_engine_name_with_dp(
task.dst_engine_id, request_info.decode_dp_rank
)
# Get or create sessions
sessions = self.worker._get_built_session(task.dst_engine_id)
@ -544,7 +545,7 @@ class MoRIIOWrapper:
self.done_write_cache_req_ids = []
self.notify_thread = None
self.sock = None
self.sessions: list["IOEngine.Session"] = []
self.sessions: list[IOEngine.Session] = []
self.paths = {}
def set_moriio_engine(self, moriio_engine):
@ -968,7 +969,7 @@ class MoRIIOConnectorScheduler:
"handshake_port"
]
)
logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }")
logger.info(f"Initializing MoRIIO Scheduler {engine_id = }")
self.side_notify_port = (
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
@ -1149,7 +1150,7 @@ class MoRIIOConnectorScheduler:
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
if (
len(self._reqs_need_pending_save[req_id][1])
* self.block_size
* self.block_size
>= req.num_prompt_tokens
):
meta.add_new_req(
@ -1171,7 +1172,7 @@ class MoRIIOConnectorScheduler:
for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
if req.num_prompt_tokens > len(block_ids)*self.block_size:
if req.num_prompt_tokens > len(block_ids) * self.block_size:
# not last chunk prefill
self._reqs_need_pending_save[req_id] = (req, block_ids)
continue
@ -1265,7 +1266,7 @@ class MoRIIOConnectorWorker:
self.mode = get_moriio_mode()
logger.info("Initializing MoRIIO worker %s", engine_id)
logging.getLogger("aiter").disabled = True
# Config.
@ -1283,11 +1284,6 @@ class MoRIIOConnectorWorker:
self.tp_rank = self.moriio_config.tp_rank
self.dp_rank = self.moriio_config.dp_rank
logger.info(
f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }"
f",{self.is_producer= }"
)
self.local_ip = self.moriio_config.local_ip
self.local_kv_port = self.moriio_config.local_kv_port
self.proxy_ip = self.moriio_config.proxy_ip
@ -1340,8 +1336,8 @@ class MoRIIOConnectorWorker:
),
)
logger.info(
"build IOEngine %s:%s",
logger.debug(
"build MORI IOEngine %s:%s",
self.moriio_config.local_ip,
self.moriio_config.local_kv_port,
)
@ -1390,8 +1386,6 @@ class MoRIIOConnectorWorker:
self.moriio_config.handshake_port
+ get_port_offset(self.dp_rank, self.tp_rank)
)
logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }")
logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}")
self.engine_id: EngineId = engine_id
self.world_size = get_tensor_model_parallel_world_size()
@ -1522,7 +1516,6 @@ class MoRIIOConnectorWorker:
return self.built_write_session[remote_engine_id]
def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1/completions"
role = "P" if self.is_producer else "D"
@ -1586,15 +1579,18 @@ class MoRIIOConnectorWorker:
if self.metadata_socket not in socks:
continue
def close(self):
if hasattr(self, '_handshake_initiation_executor'):
if hasattr(self, "_handshake_initiation_executor"):
self._handshake_initiation_executor.shutdown(wait=False)
if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t:
if (
hasattr(self, "_moriio_handshake_listener_t")
and self._moriio_handshake_listener_t
):
self._moriio_handshake_listener_t.join(timeout=0)
if hasattr(self, 'zmq_context') and self.zmq_context:
if hasattr(self, "zmq_context") and self.zmq_context:
self.zmq_context.destroy(linger=0)
self.zmq_context = None
@ -1621,12 +1617,9 @@ class MoRIIOConnectorWorker:
# Listen for new requests for metadata.
host = "*"
logger.info(
f"======> mori handshake starting listening on baseport: {base_port}"
)
path = make_zmq_path("tcp", host, base_port)
logger.info(f"======> mori handshake starting listening on path: {path}")
logger.debug(f" mori handshake starting listening on path: {path}")
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
@ -1636,20 +1629,20 @@ class MoRIIOConnectorWorker:
msg != MoRIIOConstants.GET_META_MSG
and msg != MoRIIOConstants.POP_DONE_RECV
):
logger.error("Connection listener got unexpected message %s", msg)
logger.error("Connection listener got unexpected message")
raise HandshakeError("handshake failed, unexpected msg type")
elif msg == MoRIIOConstants.GET_META_MSG:
sock.send_multipart(
(identity, b"", encoded_data)
) # send local mori io engine meta data
logger.info("MoRIIO handshake listener sent metadata to %s")
logger.debug("MoRIIO handshake listener sent metadata")
# now we send tensor meta data for each block
buf = pickle.dumps(layer_name_to_local_kv_cache_metadata)
sock.send_multipart((identity, b"", buf))
elif msg == MoRIIOConstants.POP_DONE_RECV:
_, req_id = sock.recv_multipart()
logger.info(
"MoRIIO handshake listener received done recv for req %s",
logger.debug(
"MoRIIO handshake listener received done recv for req",
req_id.decode(),
)
@ -1671,10 +1664,7 @@ class MoRIIOConnectorWorker:
port_offset = get_port_offset(remote_dp_rank, self.tp_rank)
path = make_zmq_path("tcp", host, port + port_offset)
logger.info(
"handshake Querying metadata on path: %s at remote rank %s",
path,
)
logger.debug(f"handshake Querying metadata on path:{path}")
# Send query for the request.
with zmq_ctx(zmq.DEALER, path) as sock:
@ -1694,7 +1684,7 @@ class MoRIIOConnectorWorker:
)
self.moriio_wrapper.remote_engine_ip = host
remote_agent_name=self.moriio_wrapper.register_remote_engine(
remote_agent_name = self.moriio_wrapper.register_remote_engine(
metadata.agent_metadata
)
@ -1748,7 +1738,7 @@ class MoRIIOConnectorWorker:
def request_ready(_f: Future[Any], entry=(req_id, meta)):
logger.info("MoRIIO handshake done for request %s", req_id)
self._ready_requests.put(entry)
self.load_ready_flag [remote_engine_id] = True
self.load_ready_flag[remote_engine_id] = True
self.write_ready_flags[remote_engine_id] = True
fut_list = []
@ -1999,10 +1989,10 @@ class MoRIIOConnectorWorker:
break
else:
break
def get_engine_name_with_dp(self, engine_name, dp_rank):
    """Return *engine_name* suffixed with its data-parallel rank tag.

    Produces the per-DP-rank engine identifier used as a key for
    remote-agent and session bookkeeping.
    """
    suffix = "_dp{}".format(dp_rank)
    return engine_name + suffix
def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
"""
Start loading by triggering non-blocking moriio_xfer.
@ -2022,7 +2012,7 @@ class MoRIIOConnectorWorker:
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
)
meta.remote_engine_id = remote_engine_id
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0)
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0)
if dp0_remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock:
@ -2038,20 +2028,21 @@ class MoRIIOConnectorWorker:
self._read_blocks_for_req(req_id, meta)
# Start transfers for requests whose handshakes have now finished.
while True:
while True:
if (
self._ready_requests.empty()
and remote_engine_id not in self.load_ready_flag
and wait_handshake_readd_req
):
continue
elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag:
elif (
not self._ready_requests.empty()
and remote_engine_id in self.load_ready_flag
):
self._read_blocks_for_req(*self._ready_requests.get_nowait())
break
else:
break
self._reqs_to_send.update(metadata.reqs_to_send)
@ -2089,11 +2080,6 @@ class MoRIIOConnectorWorker:
return True
return False
def _is_first_layer(self, layer_name):
    """Return True iff *layer_name* is the first key in ``self.kv_caches``.

    Relies on dict insertion order (Python 3.7+) to define "first".
    Uses ``next(iter(...))`` so only the first key is fetched instead of
    materializing the full key list on every call.

    Edge case: on an empty ``kv_caches`` this returns False for any
    real layer name (the old ``list(...)[0]`` raised IndexError), which
    is a safer no-match default.
    """
    # next(..., None): None sentinel never equals a real layer-name string.
    first_layer = next(iter(self.kv_caches), None)
    return layer_name == first_layer
def merge_contiguous_blocks(
self,
offsets_local: list[int],
@ -2230,9 +2216,8 @@ class MoRIIOConnectorWorker:
) -> None:
if self.mode == MoRIIOMode.WRITE:
return
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
sessions = self._get_built_session(dp0_engine_id)
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]