mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 12:37:06 +08:00
refine
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
64694c3e76
commit
245b71a891
@ -68,8 +68,6 @@ class MoRIIOConstants:
|
||||
OVER = b"OVER"
|
||||
COMPLETION_PREFIX = "cmpl"
|
||||
|
||||
# Default GPU count per node for standard configurations
|
||||
RANK_PER_NODE = 8
|
||||
|
||||
PING_INTERVAL = 5
|
||||
MAX_PING_RETRIES = 100000
|
||||
@ -204,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode:
|
||||
|
||||
|
||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||
return ((dp_rank) * tp_size + tp_rank) % MoRIIOConstants.RANK_PER_NODE
|
||||
return ((dp_rank) * tp_size + tp_rank)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -405,8 +403,8 @@ class MoRIIOWriter:
|
||||
return
|
||||
|
||||
# Wait for CUDA event
|
||||
#The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
|
||||
#This event is used to synchronize the kv transfer and computation tasks.
|
||||
# The current layer's attention computation must not overlap with the KV transfer task; overlapping them causes precision issues.
|
||||
# This event synchronizes the KV transfer task with the computation task.
|
||||
task.event.synchronize()
|
||||
|
||||
# Update engine ID with DP rank
|
||||
@ -546,8 +544,7 @@ class MoRIIOWrapper:
|
||||
self.done_write_cache_req_ids = []
|
||||
self.notify_thread = None
|
||||
self.sock = None
|
||||
self.sessions = []
|
||||
self.kv_caches = None
|
||||
self.sessions: list["IOEngine.Session"] = []
|
||||
self.paths = {}
|
||||
|
||||
def set_moriio_engine(self, moriio_engine):
|
||||
@ -730,7 +727,7 @@ class MoRIIOWrapper:
|
||||
path = make_zmq_path("tcp", remote_ip, str(remote_port))
|
||||
|
||||
if path not in self.paths:
|
||||
ctx = zmq.Context()
|
||||
ctx = zmq.Context.instance()
|
||||
sock = make_zmq_socket(
|
||||
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
||||
)
|
||||
@ -1033,7 +1030,7 @@ class MoRIIOConnectorScheduler:
|
||||
):
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
if path not in self.paths:
|
||||
ctx = zmq.Context()
|
||||
ctx = zmq.Context.instance()
|
||||
sock = make_zmq_socket(
|
||||
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
||||
)
|
||||
@ -1268,7 +1265,7 @@ class MoRIIOConnectorWorker:
|
||||
self.mode = get_moriio_mode()
|
||||
|
||||
logger.info("Initializing MoRIIO worker %s", engine_id)
|
||||
# for debug
|
||||
|
||||
logging.getLogger("aiter").disabled = True
|
||||
|
||||
# Config.
|
||||
@ -1377,7 +1374,7 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
self.slot_size_bytes = 0
|
||||
|
||||
self.load_ready_flag = False
|
||||
self.load_ready_flag = {}
|
||||
self.write_ready_flags = {}
|
||||
self.kv_cache_shape = None
|
||||
self.block_shape = None
|
||||
@ -1569,9 +1566,9 @@ class MoRIIOConnectorWorker:
|
||||
retry_count += 1
|
||||
|
||||
finally:
|
||||
if retry_count >= MoRIIOConstants.MAX_RETRIES:
|
||||
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
|
||||
logger.error(
|
||||
f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
|
||||
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
|
||||
)
|
||||
break
|
||||
|
||||
@ -1590,12 +1587,19 @@ class MoRIIOConnectorWorker:
|
||||
if self.metadata_socket not in socks:
|
||||
continue
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, '_handshake_initiation_executor'):
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
|
||||
if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t:
|
||||
self._moriio_handshake_listener_t.join(timeout=0)
|
||||
|
||||
if hasattr(self, 'zmq_context') and self.zmq_context:
|
||||
self.zmq_context.destroy(linger=0)
|
||||
self.zmq_context = None
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
if self._moriio_handshake_listener_t:
|
||||
self._moriio_handshake_listener_t.join(timeout=0)
|
||||
self.close()
|
||||
|
||||
@staticmethod
|
||||
def _moriio_handshake_listener(
|
||||
@ -1744,7 +1748,7 @@ class MoRIIOConnectorWorker:
|
||||
def request_ready(_f: Future[Any], entry=(req_id, meta)):
|
||||
logger.info("MoRIIO handshake done for request %s", req_id)
|
||||
self._ready_requests.put(entry)
|
||||
self.load_ready_flag = True
|
||||
self.load_ready_flag [remote_engine_id] = True
|
||||
self.write_ready_flags[remote_engine_id] = True
|
||||
|
||||
fut_list = []
|
||||
@ -1826,15 +1830,6 @@ class MoRIIOConnectorWorker:
|
||||
kv_caches_base_addr = []
|
||||
caches_data = []
|
||||
|
||||
# Note(tms): I modified this from the original region setup code.
|
||||
# K and V are now in different regions. Advantage is that we can
|
||||
# elegantly support MLA and any cases where the K and V tensors
|
||||
# are non-contiguous (it's not locally guaranteed that they will be)
|
||||
# Disadvantage is that the encoded MoRIIOAgentMetadata is now larger
|
||||
# (roughly 8KB vs 5KB).
|
||||
# Conversely for FlashInfer, K and V are transferred in the same tensor
|
||||
# to better exploit the memory layout (ie num_blocks is the first dim).
|
||||
|
||||
for cache_or_caches in kv_caches.values():
|
||||
cache_list = (
|
||||
[cache_or_caches]
|
||||
@ -2043,18 +2038,20 @@ class MoRIIOConnectorWorker:
|
||||
self._read_blocks_for_req(req_id, meta)
|
||||
# Start transfers for requests whose handshakes have now finished.
|
||||
|
||||
while True: # TODO
|
||||
while True:
|
||||
if (
|
||||
self._ready_requests.empty()
|
||||
and not self.load_ready_flag
|
||||
and remote_engine_id not in self.load_ready_flag
|
||||
and wait_handshake_readd_req
|
||||
):
|
||||
continue
|
||||
elif not self._ready_requests.empty() and self.load_ready_flag:
|
||||
elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag:
|
||||
self._read_blocks_for_req(*self._ready_requests.get_nowait())
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
|
||||
self._reqs_to_send.update(metadata.reqs_to_send)
|
||||
|
||||
@ -2233,8 +2230,8 @@ class MoRIIOConnectorWorker:
|
||||
) -> None:
|
||||
if self.mode == MoRIIOMode.WRITE:
|
||||
return
|
||||
# we only test TP<->TP in read mode
|
||||
# assert self.dp_rank>0, "only test TP<->TP in read mode"
|
||||
|
||||
|
||||
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
|
||||
sessions = self._get_built_session(dp0_engine_id)
|
||||
|
||||
@ -2248,7 +2245,7 @@ class MoRIIOConnectorWorker:
|
||||
layer_name
|
||||
)
|
||||
transfer_status = self.moriio_wrapper.read_remote_data(
|
||||
offs[0], offs[1], offs[2], sessions[sess_idx]
|
||||
offs[2], offs[0], offs[1], sessions[sess_idx]
|
||||
)
|
||||
|
||||
self._recving_transfers[request_id].append(transfer_status)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user