mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 06:35:16 +08:00
[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
parent
9d98ab5ec6
commit
a5450f11c9
@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
@ -115,14 +116,14 @@ class MooncakeTransferEngine:
|
||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
else:
|
||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
|
||||
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
|
||||
def initialize(self, local_hostname: str, metadata_server: str,
|
||||
@ -176,7 +177,7 @@ class MooncakeTransferEngine:
|
||||
|
||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||
"""Asynchronously wait for ACK from the receiver."""
|
||||
ack = self.sender_ack.recv_pyobj()
|
||||
ack = self.sender_ack.recv()
|
||||
if ack != b'ACK':
|
||||
logger.error("Failed to receive ACK from the receiver")
|
||||
|
||||
@ -187,18 +188,22 @@ class MooncakeTransferEngine:
|
||||
length = len(user_data)
|
||||
src_ptr = self.allocate_managed_buffer(length)
|
||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||
self.sender_socket.send_pyobj((src_ptr, length))
|
||||
self.sender_socket.send_multipart(
|
||||
[struct.pack("!Q", src_ptr),
|
||||
struct.pack("!Q", length)])
|
||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||
|
||||
def recv_bytes(self) -> bytes:
|
||||
"""Receive bytes from the remote process."""
|
||||
src_ptr, length = self.receiver_socket.recv_pyobj()
|
||||
data = self.receiver_socket.recv_multipart()
|
||||
src_ptr = struct.unpack("!Q", data[0])[0]
|
||||
length = struct.unpack("!Q", data[1])[0]
|
||||
dst_ptr = self.allocate_managed_buffer(length)
|
||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||
|
||||
# Buffer cleanup
|
||||
self.receiver_ack.send_pyobj(b'ACK')
|
||||
self.receiver_ack.send(b'ACK')
|
||||
self.free_managed_buffer(dst_ptr, length)
|
||||
|
||||
return ret
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user