refine code

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 04:36:19 +00:00
parent 68a2333339
commit 70ea1b2460

View File

@ -7,6 +7,13 @@ import pickle
import queue
import threading
import time
import msgpack
import msgspec
import numpy as np
import torch
import zmq
import logging
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
@ -14,12 +21,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from weakref import ref as weakref_ref
import msgpack
import msgspec
import numpy as np
import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
@ -50,7 +51,6 @@ if TYPE_CHECKING:
from dataclasses import field
from enum import Enum
from queue import Empty, Queue
import logging
logger = init_logger(__name__)
@ -70,7 +70,9 @@ class MoRIIOConstants:
# Default GPU count per node for standard configurations
RANK_PER_NODE = 8
PING_INTERVAL = 5
MAX_PING_RETRIES = 100000
try:
import mori
@ -80,7 +82,6 @@ try:
IOEngine,
IOEngineConfig,
MemoryDesc,
StatusCode,
)
logger.info("MoRIIO is available")
@ -404,10 +405,12 @@ class MoRIIOWriter:
return
# Wait for CUDA event
#The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
#This event is used to synchronize the kv transfer and computation tasks.
task.event.synchronize()
# Update engine ID with DP rank
task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}"
task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank)
# Get or create sessions
sessions = self.worker._get_built_session(task.dst_engine_id)
@ -665,7 +668,6 @@ class MoRIIOWrapper:
except Exception as e:
logger.error(f"Error processing message: {e}")
raise HandshakeError(f"Error processing message: {e}") from e
continue
self.notify_thread = threading.Thread(
target=_async_wait, daemon=True, name="moriio-notify-listener"
@ -685,7 +687,6 @@ class MoRIIOWrapper:
data = msgpack.loads(msg)
if isinstance(data, dict) and "req_id" in data:
self._handle_structured_message(data)
handled = True
return
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
@ -841,6 +842,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None
# assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id = (
@ -1149,10 +1151,8 @@ class MoRIIOConnectorScheduler:
)
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
if (
len(
self._reqs_need_pending_save[req_id][1]
len(self._reqs_need_pending_save[req_id][1])
* self.block_size
)
>= req.num_prompt_tokens
):
meta.add_new_req(
@ -1363,13 +1363,9 @@ class MoRIIOConnectorWorker:
)
# Agent.
self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
self.moriio_wrapper.set_backend_type(BackendType.RDMA)
self.moriio_wrapper.notify_port = self.moriio_config.notify_port
self.local_kv_cache_metadata = []
self.local_kv_cache_size = []
self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict()
@ -1450,7 +1446,6 @@ class MoRIIOConnectorWorker:
self.use_mla = self.model_config.use_mla
self.built_session = False
self.built_write_session: defaultdict[str, list] = defaultdict(list)
self._write_session_lock = threading.Lock()
self.debug_cache = []
backend = get_attn_backend(
self.model_config.get_head_size(),
@ -1467,15 +1462,6 @@ class MoRIIOConnectorWorker:
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER
logger.debug("Detected attention backend %s", self.backend_name)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
####write worker###
# self._write_task_q: Queue[WriteTask] = Queue()
# self._write_worker_started = False
# self._write_worker_lock = threading.Lock()
# self._deferred_tasks: list[WriteTask] = []
# ####write worker###
def schedule_write_blocks(
self,
request_id: str,
@ -1539,8 +1525,6 @@ class MoRIIOConnectorWorker:
return self.built_write_session[remote_engine_id]
def _ping(self, zmq_context):
PING_INTERVAL = 5
MAX_RETRIES = 100000
http_request_address = f"http://{self.request_address}/v1/completions"
role = "P" if self.is_producer else "D"
@ -1585,13 +1569,13 @@ class MoRIIOConnectorWorker:
retry_count += 1
finally:
if retry_count >= MAX_RETRIES:
if retry_count >= MoRIIOConstants.MAX_RETRIES:
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop."
f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
)
break
time.sleep(PING_INTERVAL)
time.sleep(MoRIIOConstants.PING_INTERVAL)
index += 1
def handle_proxy_request(self):
@ -1603,7 +1587,6 @@ class MoRIIOConnectorWorker:
socks = dict(self.poller.poll())
logger.debug(f"handle_proxy_request: {socks = }")
# TODO: inkcherry , check here?
if self.metadata_socket not in socks:
continue
@ -1665,8 +1648,6 @@ class MoRIIOConnectorWorker:
"MoRIIO handshake listener received done recv for req %s",
req_id.decode(),
)
else:
pass
def _moriio_handshake(
self,
@ -1709,9 +1690,7 @@ class MoRIIOConnectorWorker:
)
self.moriio_wrapper.remote_engine_ip = host
remote_agent_name = self.moriio_wrapper.register_remote_engine(
metadata.agent_metadata
)
remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key
logger.info(
@ -1873,8 +1852,6 @@ class MoRIIOConnectorWorker:
if layer_name not in self.layer_name_to_local_kv_cache_metadata:
self.layer_name_to_local_kv_cache_metadata[layer_name] = []
# for cache in cache_list:
# moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache)
moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache)
self.layer_name_to_local_kv_cache_metadata[layer_name].append(
moriio_mem_metadata
@ -1942,7 +1919,7 @@ class MoRIIOConnectorWorker:
to track which workers are done.
"""
done_sending, done_recving = set(), set()
done_sending = set()
if self.is_producer:
done_sending = self.moriio_wrapper.pop_finished_req_ids()
@ -1991,7 +1968,6 @@ class MoRIIOConnectorWorker:
remote_engine_id = None
for req_id, meta in metadata.reqs_to_save.items():
remote_engine_id = meta.remote_engine_id
# we only need to check if dp0 in rank
remote_engine_id = (
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
@ -2030,13 +2006,15 @@ class MoRIIOConnectorWorker:
break
else:
break
def get_engine_name_with_dp(self, engine_name,dp_rank):
return f"{engine_name}_dp{dp_rank}"
def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
"""
Start loading by triggering non-blocking moriio_xfer.
We check for these trnxs to complete in each step().
"""
# print("start load kv")
if self.is_producer:
self.moriio_wrapper.async_wait_reqid()
return
@ -2051,7 +2029,7 @@ class MoRIIOConnectorWorker:
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
)
meta.remote_engine_id = remote_engine_id
dp0_remote_engine_id = f"{remote_engine_id}_dp0"
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0)
if dp0_remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock:
@ -2216,7 +2194,6 @@ class MoRIIOConnectorWorker:
blknum, blksize, hs = self.kv_cache_shape
hn = 1
block_stride = stride[0]
ktov_stride = None
else:
_, blknum, blksize, hn, hs = self.kv_cache_shape
ktov_stride = stride[0]
@ -2260,21 +2237,20 @@ class MoRIIOConnectorWorker:
return
# we only test TP<->TP in read mode
# assert self.dp_rank>0, "only test TP<->TP in read mode"
dst_engine_id += "_dp0"
sessions = self._get_built_session(dst_engine_id)
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
sessions = self._get_built_session(dp0_engine_id)
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
offs = self._compute_block_transfer_offsets(
first_layer, local_block_ids, remote_block_ids
)
a, b, c = offs[0], offs[1], offs[2]
for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
layer_name
)
transfer_status = self.moriio_wrapper.read_remote_data(
c, a, b, sessions[sess_idx]
offs[0], offs[1], offs[2], sessions[sess_idx]
)
self._recving_transfers[request_id].append(transfer_status)