refine code

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 04:36:19 +00:00
parent 68a2333339
commit 70ea1b2460

View File

@ -7,6 +7,13 @@ import pickle
import queue import queue
import threading import threading
import time import time
import msgpack
import msgspec
import numpy as np
import torch
import zmq
import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
@ -14,12 +21,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from weakref import ref as weakref_ref from weakref import ref as weakref_ref
import msgpack
import msgspec
import numpy as np
import torch
import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -50,7 +51,6 @@ if TYPE_CHECKING:
from dataclasses import field from dataclasses import field
from enum import Enum from enum import Enum
from queue import Empty, Queue from queue import Empty, Queue
import logging
logger = init_logger(__name__) logger = init_logger(__name__)
@ -70,7 +70,9 @@ class MoRIIOConstants:
# Default GPU count per node for standard configurations # Default GPU count per node for standard configurations
RANK_PER_NODE = 8 RANK_PER_NODE = 8
PING_INTERVAL = 5
MAX_PING_RETRIES = 100000
try: try:
import mori import mori
@ -80,7 +82,6 @@ try:
IOEngine, IOEngine,
IOEngineConfig, IOEngineConfig,
MemoryDesc, MemoryDesc,
StatusCode,
) )
logger.info("MoRIIO is available") logger.info("MoRIIO is available")
@ -404,10 +405,12 @@ class MoRIIOWriter:
return return
# Wait for CUDA event # Wait for CUDA event
#The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
#This event is used to synchronize the kv transfer and computation tasks.
task.event.synchronize() task.event.synchronize()
# Update engine ID with DP rank # Update engine ID with DP rank
task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank)
# Get or create sessions # Get or create sessions
sessions = self.worker._get_built_session(task.dst_engine_id) sessions = self.worker._get_built_session(task.dst_engine_id)
@ -665,7 +668,6 @@ class MoRIIOWrapper:
except Exception as e: except Exception as e:
logger.error(f"Error processing message: {e}") logger.error(f"Error processing message: {e}")
raise HandshakeError(f"Error processing message: {e}") from e raise HandshakeError(f"Error processing message: {e}") from e
continue
self.notify_thread = threading.Thread( self.notify_thread = threading.Thread(
target=_async_wait, daemon=True, name="moriio-notify-listener" target=_async_wait, daemon=True, name="moriio-notify-listener"
@ -685,7 +687,6 @@ class MoRIIOWrapper:
data = msgpack.loads(msg) data = msgpack.loads(msg)
if isinstance(data, dict) and "req_id" in data: if isinstance(data, dict) and "req_id" in data:
self._handle_structured_message(data) self._handle_structured_message(data)
handled = True
return return
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
@ -841,6 +842,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: Optional["KVCacheConfig"] = None,
): ):
super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
# assert vllm_config.kv_transfer_config.engine_id is not None # assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id = ( self.engine_id = (
@ -1149,10 +1151,8 @@ class MoRIIOConnectorScheduler:
) )
self._reqs_need_pending_save[req_id] = (req, updated_blocks) self._reqs_need_pending_save[req_id] = (req, updated_blocks)
if ( if (
len( len(self._reqs_need_pending_save[req_id][1])
self._reqs_need_pending_save[req_id][1]
* self.block_size * self.block_size
)
>= req.num_prompt_tokens >= req.num_prompt_tokens
): ):
meta.add_new_req( meta.add_new_req(
@ -1363,13 +1363,9 @@ class MoRIIOConnectorWorker:
) )
# Agent. # Agent.
self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.set_backend_type(BackendType.RDMA)
self.moriio_wrapper.notify_port = self.moriio_config.notify_port self.moriio_wrapper.notify_port = self.moriio_config.notify_port
self.local_kv_cache_metadata = [] self.local_kv_cache_metadata = []
self.local_kv_cache_size = [] self.local_kv_cache_size = []
self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict()
@ -1450,7 +1446,6 @@ class MoRIIOConnectorWorker:
self.use_mla = self.model_config.use_mla self.use_mla = self.model_config.use_mla
self.built_session = False self.built_session = False
self.built_write_session: defaultdict[str, list] = defaultdict(list) self.built_write_session: defaultdict[str, list] = defaultdict(list)
self._write_session_lock = threading.Lock()
self.debug_cache = [] self.debug_cache = []
backend = get_attn_backend( backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
@ -1467,15 +1462,6 @@ class MoRIIOConnectorWorker:
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER # self._use_flashinfer = attn_backend == _Backend.FLASHINFER
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
####write worker###
# self._write_task_q: Queue[WriteTask] = Queue()
# self._write_worker_started = False
# self._write_worker_lock = threading.Lock()
# self._deferred_tasks: list[WriteTask] = []
# ####write worker###
def schedule_write_blocks( def schedule_write_blocks(
self, self,
request_id: str, request_id: str,
@ -1539,8 +1525,6 @@ class MoRIIOConnectorWorker:
return self.built_write_session[remote_engine_id] return self.built_write_session[remote_engine_id]
def _ping(self, zmq_context): def _ping(self, zmq_context):
PING_INTERVAL = 5
MAX_RETRIES = 100000
http_request_address = f"http://{self.request_address}/v1/completions" http_request_address = f"http://{self.request_address}/v1/completions"
role = "P" if self.is_producer else "D" role = "P" if self.is_producer else "D"
@ -1585,13 +1569,13 @@ class MoRIIOConnectorWorker:
retry_count += 1 retry_count += 1
finally: finally:
if retry_count >= MAX_RETRIES: if retry_count >= MoRIIOConstants.MAX_RETRIES:
logger.error( logger.error(
f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop." f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
) )
break break
time.sleep(PING_INTERVAL) time.sleep(MoRIIOConstants.PING_INTERVAL)
index += 1 index += 1
def handle_proxy_request(self): def handle_proxy_request(self):
@ -1603,7 +1587,6 @@ class MoRIIOConnectorWorker:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll())
logger.debug(f"handle_proxy_request: {socks = }") logger.debug(f"handle_proxy_request: {socks = }")
# TODO: inkcherry , check here?
if self.metadata_socket not in socks: if self.metadata_socket not in socks:
continue continue
@ -1665,8 +1648,6 @@ class MoRIIOConnectorWorker:
"MoRIIO handshake listener received done recv for req %s", "MoRIIO handshake listener received done recv for req %s",
req_id.decode(), req_id.decode(),
) )
else:
pass
def _moriio_handshake( def _moriio_handshake(
self, self,
@ -1709,9 +1690,7 @@ class MoRIIOConnectorWorker:
) )
self.moriio_wrapper.remote_engine_ip = host self.moriio_wrapper.remote_engine_ip = host
remote_agent_name = self.moriio_wrapper.register_remote_engine(
metadata.agent_metadata
)
remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key
logger.info( logger.info(
@ -1873,8 +1852,6 @@ class MoRIIOConnectorWorker:
if layer_name not in self.layer_name_to_local_kv_cache_metadata: if layer_name not in self.layer_name_to_local_kv_cache_metadata:
self.layer_name_to_local_kv_cache_metadata[layer_name] = [] self.layer_name_to_local_kv_cache_metadata[layer_name] = []
# for cache in cache_list:
# moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache)
moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache) moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache)
self.layer_name_to_local_kv_cache_metadata[layer_name].append( self.layer_name_to_local_kv_cache_metadata[layer_name].append(
moriio_mem_metadata moriio_mem_metadata
@ -1942,7 +1919,7 @@ class MoRIIOConnectorWorker:
to track which workers are done. to track which workers are done.
""" """
done_sending, done_recving = set(), set() done_sending = set()
if self.is_producer: if self.is_producer:
done_sending = self.moriio_wrapper.pop_finished_req_ids() done_sending = self.moriio_wrapper.pop_finished_req_ids()
@ -1991,7 +1968,6 @@ class MoRIIOConnectorWorker:
remote_engine_id = None remote_engine_id = None
for req_id, meta in metadata.reqs_to_save.items(): for req_id, meta in metadata.reqs_to_save.items():
remote_engine_id = meta.remote_engine_id
# we only need to check if dp0 in rank # we only need to check if dp0 in rank
remote_engine_id = ( remote_engine_id = (
str(meta.remote_host) + ":" + str(meta.remote_handshake_port) str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
@ -2030,13 +2006,15 @@ class MoRIIOConnectorWorker:
break break
else: else:
break break
def get_engine_name_with_dp(self, engine_name,dp_rank):
return f"{engine_name}_dp{dp_rank}"
def start_load_kv(self, metadata: MoRIIOConnectorMetadata): def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
""" """
Start loading by triggering non-blocking moriio_xfer. Start loading by triggering non-blocking moriio_xfer.
We check for these trnxs to complete in each step(). We check for these trnxs to complete in each step().
""" """
# print("start load kv")
if self.is_producer: if self.is_producer:
self.moriio_wrapper.async_wait_reqid() self.moriio_wrapper.async_wait_reqid()
return return
@ -2051,7 +2029,7 @@ class MoRIIOConnectorWorker:
str(meta.remote_host) + ":" + str(meta.remote_handshake_port) str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
) )
meta.remote_engine_id = remote_engine_id meta.remote_engine_id = remote_engine_id
dp0_remote_engine_id = f"{remote_engine_id}_dp0" dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0)
if dp0_remote_engine_id not in self._remote_agents: if dp0_remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata. # Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock: with self._handshake_lock:
@ -2216,7 +2194,6 @@ class MoRIIOConnectorWorker:
blknum, blksize, hs = self.kv_cache_shape blknum, blksize, hs = self.kv_cache_shape
hn = 1 hn = 1
block_stride = stride[0] block_stride = stride[0]
ktov_stride = None
else: else:
_, blknum, blksize, hn, hs = self.kv_cache_shape _, blknum, blksize, hn, hs = self.kv_cache_shape
ktov_stride = stride[0] ktov_stride = stride[0]
@ -2260,21 +2237,20 @@ class MoRIIOConnectorWorker:
return return
# we only test TP<->TP in read mode # we only test TP<->TP in read mode
# assert self.dp_rank>0, "only test TP<->TP in read mode" # assert self.dp_rank>0, "only test TP<->TP in read mode"
dst_engine_id += "_dp0" dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
sessions = self._get_built_session(dst_engine_id) sessions = self._get_built_session(dp0_engine_id)
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
offs = self._compute_block_transfer_offsets( offs = self._compute_block_transfer_offsets(
first_layer, local_block_ids, remote_block_ids first_layer, local_block_ids, remote_block_ids
) )
a, b, c = offs[0], offs[1], offs[2]
for layer_name in self.layer_name_to_local_kv_cache_metadata.keys(): for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
layer_name layer_name
) )
transfer_status = self.moriio_wrapper.read_remote_data( transfer_status = self.moriio_wrapper.read_remote_data(
c, a, b, sessions[sess_idx] offs[0], offs[1], offs[2], sessions[sess_idx]
) )
self._recving_transfers[request_id].append(transfer_status) self._recving_transfers[request_id].append(transfer_status)