mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 01:55:01 +08:00
[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
parent
9d98ab5ec6
commit
a5450f11c9
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import struct
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
@@ -115,14 +116,14 @@ class MooncakeTransferEngine:
|
|||||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||||
if kv_rank == 0:
|
if kv_rank == 0:
|
||||||
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
|
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||||
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
|
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||||
else:
|
else:
|
||||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||||
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
|
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||||
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
|
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||||
|
|
||||||
def initialize(self, local_hostname: str, metadata_server: str,
|
def initialize(self, local_hostname: str, metadata_server: str,
|
||||||
@@ -176,7 +177,7 @@ class MooncakeTransferEngine:
|
|||||||
|
|
||||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||||
"""Asynchronously wait for ACK from the receiver."""
|
"""Asynchronously wait for ACK from the receiver."""
|
||||||
ack = self.sender_ack.recv_pyobj()
|
ack = self.sender_ack.recv()
|
||||||
if ack != b'ACK':
|
if ack != b'ACK':
|
||||||
logger.error("Failed to receive ACK from the receiver")
|
logger.error("Failed to receive ACK from the receiver")
|
||||||
|
|
||||||
@@ -187,18 +188,22 @@ class MooncakeTransferEngine:
|
|||||||
length = len(user_data)
|
length = len(user_data)
|
||||||
src_ptr = self.allocate_managed_buffer(length)
|
src_ptr = self.allocate_managed_buffer(length)
|
||||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||||
self.sender_socket.send_pyobj((src_ptr, length))
|
self.sender_socket.send_multipart(
|
||||||
|
[struct.pack("!Q", src_ptr),
|
||||||
|
struct.pack("!Q", length)])
|
||||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||||
|
|
||||||
def recv_bytes(self) -> bytes:
|
def recv_bytes(self) -> bytes:
|
||||||
"""Receive bytes from the remote process."""
|
"""Receive bytes from the remote process."""
|
||||||
src_ptr, length = self.receiver_socket.recv_pyobj()
|
data = self.receiver_socket.recv_multipart()
|
||||||
|
src_ptr = struct.unpack("!Q", data[0])[0]
|
||||||
|
length = struct.unpack("!Q", data[1])[0]
|
||||||
dst_ptr = self.allocate_managed_buffer(length)
|
dst_ptr = self.allocate_managed_buffer(length)
|
||||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||||
|
|
||||||
# Buffer cleanup
|
# Buffer cleanup
|
||||||
self.receiver_ack.send_pyobj(b'ACK')
|
self.receiver_ack.send(b'ACK')
|
||||||
self.free_managed_buffer(dst_ptr, length)
|
self.free_managed_buffer(dst_ptr, length)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user