[Perf] Exploit out-of-band buffers in shm_broadcast (#26961)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-16 20:08:03 -07:00 committed by GitHub
parent 4ffd6e8942
commit ab81379ea6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event
from typing import Any
from typing import TYPE_CHECKING, Any
from unittest.mock import patch
import torch
@ -33,8 +34,18 @@ from vllm.utils import (
is_valid_ipv6_address,
)
if TYPE_CHECKING:
from _typeshed import SizedBuffer
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")
logger = init_logger(__name__)
@ -225,7 +236,7 @@ class MessageQueue:
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: list[int] | None = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB
max_chunks: int = 10,
connect_ip: str | None = None,
):
@ -505,18 +516,41 @@ class MessageQueue:
def enqueue(self, obj, timeout: float | None = None):
"""Write to message queue with optional timeout (in seconds)"""
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
all_buffers: list[SizedBuffer] = [b""]
total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size
def oob_callback(buf: PickleBuffer) -> bool:
raw_buf = buf.raw()
if len(raw_buf) < 1024 * 1024:
# In-line buffers smaller than 1MiB.
return True
all_buffers.append(raw_buf)
nonlocal total_bytes
total_bytes += len(raw_buf) + 4
return False
all_buffers[0] = pickle.dumps(
obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
self.local_socket.send_multipart(all_buffers, copy=False)
else:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1 : len(serialized_obj) + 1] = serialized_obj
offset = 3
buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count
for buffer in all_buffers:
buf_len = len(buffer)
# prepend each buffer with 4 bytes containing its size.
buf_offset = offset + 4
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
self.remote_socket.send_multipart(all_buffers, copy=False)
def dequeue(
self,
@ -529,10 +563,15 @@ class MessageQueue:
with self.acquire_read(timeout, cancel, indefinite) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
# pickle format contains the size information internally
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf[1:])
offset = 3
buf_count = from_bytes_big(buf[1:offset])
all_buffers = []
for i in range(buf_count):
buf_offset = offset + 4
buf_len = from_bytes_big(buf[offset:buf_offset])
offset = buf_offset + buf_len
all_buffers.append(buf[buf_offset:offset])
obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
if overflow:
obj = MessageQueue.recv(self.local_socket, timeout)
elif self._is_remote_reader:
@ -546,15 +585,14 @@ class MessageQueue:
timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms):
raise TimeoutError
recv = socket.recv(copy=False)
return pickle.loads(recv.buffer)
recv, *recv_oob = socket.recv_multipart(copy=False)
return pickle.loads(recv, buffers=recv_oob)
def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
return obj
else:
return self.dequeue()
return self.dequeue()
@staticmethod
def create_from_process_group(