[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
Russell Bryant 2025-04-25 12:53:23 -04:00 committed by GitHub
parent 9d98ab5ec6
commit a5450f11c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
 import json
 import os
+import struct
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import Optional, Union
@ -115,14 +116,14 @@ class MooncakeTransferEngine:
 p_rank_offset = int(p_port) + 8 + self.local_rank * 2
 d_rank_offset = int(d_port) + 8 + self.local_rank * 2
 if kv_rank == 0:
-self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
+self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
 self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
 self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
-self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
+self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
 else:
 self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
-self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
-self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
+self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
+self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
 self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
 def initialize(self, local_hostname: str, metadata_server: str,
@ -176,7 +177,7 @@ class MooncakeTransferEngine:
 def wait_for_ack(self, src_ptr: int, length: int) -> None:
 """Asynchronously wait for ACK from the receiver."""
-ack = self.sender_ack.recv_pyobj()
+ack = self.sender_ack.recv()
 if ack != b'ACK':
 logger.error("Failed to receive ACK from the receiver")
@ -187,18 +188,22 @@ class MooncakeTransferEngine:
 length = len(user_data)
 src_ptr = self.allocate_managed_buffer(length)
 self.write_bytes_to_buffer(src_ptr, user_data, length)
-self.sender_socket.send_pyobj((src_ptr, length))
+self.sender_socket.send_multipart(
+[struct.pack("!Q", src_ptr),
+struct.pack("!Q", length)])
 self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
 def recv_bytes(self) -> bytes:
 """Receive bytes from the remote process."""
-src_ptr, length = self.receiver_socket.recv_pyobj()
+data = self.receiver_socket.recv_multipart()
+src_ptr = struct.unpack("!Q", data[0])[0]
+length = struct.unpack("!Q", data[1])[0]
 dst_ptr = self.allocate_managed_buffer(length)
 self.transfer_sync(dst_ptr, src_ptr, length)
 ret = self.read_bytes_from_buffer(dst_ptr, length)
 # Buffer cleanup
-self.receiver_ack.send_pyobj(b'ACK')
+self.receiver_ack.send(b'ACK')
 self.free_managed_buffer(dst_ptr, length)
 return ret