diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index cd201503bf17..5dba80fe8080 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import functools import pickle import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory +from pickle import PickleBuffer from threading import Event -from typing import Any +from typing import TYPE_CHECKING, Any from unittest.mock import patch import torch @@ -33,8 +34,18 @@ from vllm.utils import ( is_valid_ipv6_address, ) +if TYPE_CHECKING: + from _typeshed import SizedBuffer + VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL +from_bytes_big = functools.partial(int.from_bytes, byteorder="big") + + +def to_bytes_big(value: int, size: int) -> bytes: + return value.to_bytes(size, byteorder="big") + + logger = init_logger(__name__) @@ -225,7 +236,7 @@ class MessageQueue: n_reader, # number of all readers n_local_reader, # number of local readers through shared memory local_reader_ranks: list[int] | None = None, - max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB max_chunks: int = 10, connect_ip: str | None = None, ): @@ -505,18 +516,41 @@ class MessageQueue: def enqueue(self, obj, timeout: float | None = None): """Write to message queue with optional timeout (in seconds)""" assert self._is_writer, "Only writers can enqueue" - serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + all_buffers: list[SizedBuffer] = [b""] + total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size + + def oob_callback(buf: PickleBuffer) -> bool: + raw_buf = buf.raw() + if len(raw_buf) < 1024 * 1024: + # In-line buffers smaller than 1MiB. + return True + all_buffers.append(raw_buf) + nonlocal total_bytes + total_bytes += len(raw_buf) + 4 + return False + + all_buffers[0] = pickle.dumps( + obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback + ) if self.n_local_reader > 0: - if len(serialized_obj) >= self.buffer.max_chunk_bytes: + if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes: with self.acquire_write(timeout) as buf: buf[0] = 1 # overflow - self.local_socket.send(serialized_obj) + self.local_socket.send_multipart(all_buffers, copy=False) else: with self.acquire_write(timeout) as buf: buf[0] = 0 # not overflow - buf[1 : len(serialized_obj) + 1] = serialized_obj + offset = 3 + buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count + for buffer in all_buffers: + buf_len = len(buffer) + # prepend each buffer with 4 bytes containing its size. + buf_offset = offset + 4 + buf[offset:buf_offset] = to_bytes_big(buf_len, 4) + buf[buf_offset : (offset := buf_offset + buf_len)] = buffer + if self.n_remote_reader > 0: - self.remote_socket.send(serialized_obj) + self.remote_socket.send_multipart(all_buffers, copy=False) def dequeue( self, @@ -529,10 +563,15 @@ class MessageQueue: with self.acquire_read(timeout, cancel, indefinite) as buf: overflow = buf[0] == 1 if not overflow: - # no need to know the size of serialized object - # pickle format contains the size information internally - # see https://docs.python.org/3/library/pickle.html - obj = pickle.loads(buf[1:]) + offset = 3 + buf_count = from_bytes_big(buf[1:offset]) + all_buffers = [] + for i in range(buf_count): + buf_offset = offset + 4 + buf_len = from_bytes_big(buf[offset:buf_offset]) + offset = buf_offset + buf_len + all_buffers.append(buf[buf_offset:offset]) + obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:]) if overflow: obj = MessageQueue.recv(self.local_socket, timeout) elif self._is_remote_reader: @@ -546,15 +585,14 @@ class MessageQueue: timeout_ms = None if timeout is None else int(timeout * 1000) if not socket.poll(timeout=timeout_ms): raise TimeoutError - recv = socket.recv(copy=False) - return pickle.loads(recv.buffer) + recv, *recv_oob = socket.recv_multipart(copy=False) + return pickle.loads(recv, buffers=recv_oob) def broadcast_object(self, obj=None): if self._is_writer: self.enqueue(obj) return obj - else: - return self.dequeue() + return self.dequeue() @staticmethod def create_from_process_group(