[Perf] Exploit out-of-band buffers in shm_broadcast (#26961)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-16 20:08:03 -07:00 committed by GitHub
parent 4ffd6e8942
commit ab81379ea6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle import pickle
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing import shared_memory from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event from threading import Event
from typing import Any from typing import TYPE_CHECKING, Any
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@ -33,8 +34,18 @@ from vllm.utils import (
is_valid_ipv6_address, is_valid_ipv6_address,
) )
if TYPE_CHECKING:
from _typeshed import SizedBuffer
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")
logger = init_logger(__name__) logger = init_logger(__name__)
@ -225,7 +236,7 @@ class MessageQueue:
n_reader, # number of all readers n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory n_local_reader, # number of local readers through shared memory
local_reader_ranks: list[int] | None = None, local_reader_ranks: list[int] | None = None,
max_chunk_bytes: int = 1024 * 1024 * 10, max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB
max_chunks: int = 10, max_chunks: int = 10,
connect_ip: str | None = None, connect_ip: str | None = None,
): ):
@ -505,18 +516,41 @@ class MessageQueue:
def enqueue(self, obj, timeout: float | None = None): def enqueue(self, obj, timeout: float | None = None):
"""Write to message queue with optional timeout (in seconds)""" """Write to message queue with optional timeout (in seconds)"""
assert self._is_writer, "Only writers can enqueue" assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) all_buffers: list[SizedBuffer] = [b""]
total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size
def oob_callback(buf: PickleBuffer) -> bool:
raw_buf = buf.raw()
if len(raw_buf) < 1024 * 1024:
# In-line buffers smaller than 1MiB.
return True
all_buffers.append(raw_buf)
nonlocal total_bytes
total_bytes += len(raw_buf) + 4
return False
all_buffers[0] = pickle.dumps(
obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
)
if self.n_local_reader > 0: if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes: if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
with self.acquire_write(timeout) as buf: with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow buf[0] = 1 # overflow
self.local_socket.send(serialized_obj) self.local_socket.send_multipart(all_buffers, copy=False)
else: else:
with self.acquire_write(timeout) as buf: with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow buf[0] = 0 # not overflow
buf[1 : len(serialized_obj) + 1] = serialized_obj offset = 3
buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count
for buffer in all_buffers:
buf_len = len(buffer)
# prepend each buffer with 4 bytes containing its size.
buf_offset = offset + 4
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
if self.n_remote_reader > 0: if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj) self.remote_socket.send_multipart(all_buffers, copy=False)
def dequeue( def dequeue(
self, self,
@ -529,10 +563,15 @@ class MessageQueue:
with self.acquire_read(timeout, cancel, indefinite) as buf: with self.acquire_read(timeout, cancel, indefinite) as buf:
overflow = buf[0] == 1 overflow = buf[0] == 1
if not overflow: if not overflow:
# no need to know the size of serialized object offset = 3
# pickle format contains the size information internally buf_count = from_bytes_big(buf[1:offset])
# see https://docs.python.org/3/library/pickle.html all_buffers = []
obj = pickle.loads(buf[1:]) for i in range(buf_count):
buf_offset = offset + 4
buf_len = from_bytes_big(buf[offset:buf_offset])
offset = buf_offset + buf_len
all_buffers.append(buf[buf_offset:offset])
obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
if overflow: if overflow:
obj = MessageQueue.recv(self.local_socket, timeout) obj = MessageQueue.recv(self.local_socket, timeout)
elif self._is_remote_reader: elif self._is_remote_reader:
@ -546,14 +585,13 @@ class MessageQueue:
timeout_ms = None if timeout is None else int(timeout * 1000) timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms): if not socket.poll(timeout=timeout_ms):
raise TimeoutError raise TimeoutError
recv = socket.recv(copy=False) recv, *recv_oob = socket.recv_multipart(copy=False)
return pickle.loads(recv.buffer) return pickle.loads(recv, buffers=recv_oob)
def broadcast_object(self, obj=None): def broadcast_object(self, obj=None):
if self._is_writer: if self._is_writer:
self.enqueue(obj) self.enqueue(obj)
return obj return obj
else:
return self.dequeue() return self.dequeue()
@staticmethod @staticmethod