[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
Russell Bryant 2025-04-25 12:53:23 -04:00 committed by GitHub
parent 9d98ab5ec6
commit a5450f11c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union
@ -115,14 +116,14 @@ class MooncakeTransferEngine:
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str,
@ -176,7 +177,7 @@ class MooncakeTransferEngine:
def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv_pyobj()
ack = self.sender_ack.recv()
if ack != b'ACK':
logger.error("Failed to receive ACK from the receiver")
@ -187,18 +188,22 @@ class MooncakeTransferEngine:
length = len(user_data)
src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_pyobj((src_ptr, length))
self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr),
struct.pack("!Q", length)])
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process."""
src_ptr, length = self.receiver_socket.recv_pyobj()
data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup
self.receiver_ack.send_pyobj(b'ACK')
self.receiver_ack.send(b'ACK')
self.free_managed_buffer(dst_ptr, length)
return ret