[Perf] Exploit out-of-band buffers in shm_broadcast (#26961)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent 4ffd6e8942
commit ab81379ea6
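The change builds on pickle protocol 5 (PEP 574) out-of-band buffers: large data buffers are handed to a buffer_callback instead of being copied into the pickle stream, and the reader re-attaches them through the buffers= argument of pickle.loads(). A minimal, self-contained sketch of that mechanism (not vLLM code; the NumPy array is just a convenient buffer-producing object):

import pickle
from pickle import PickleBuffer

import numpy as np

obj = {"step": 7, "block": np.zeros((1024, 1024), dtype=np.float32)}  # ~4 MiB payload

out_of_band: list[PickleBuffer] = []
# A callback returning a falsy value (list.append returns None) keeps the
# buffer out-of-band; returning True would serialize it in-line instead,
# which is what the enqueue() change below does for buffers under 1 MiB.
main_stream = pickle.dumps(obj, protocol=5, buffer_callback=out_of_band.append)

assert len(main_stream) < 10_000  # the 4 MiB of array data is not in here
assert len(out_of_band) == 1      # it travels as a separate raw buffer

# The reader supplies the side-channel buffers back to pickle.
restored = pickle.loads(main_stream, buffers=out_of_band)
assert np.array_equal(restored["block"], obj["block"])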
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import functools
 import pickle
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
+from pickle import PickleBuffer
 from threading import Event
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from unittest.mock import patch

 import torch
@@ -33,8 +34,18 @@ from vllm.utils import (
     is_valid_ipv6_address,
 )

+if TYPE_CHECKING:
+    from _typeshed import SizedBuffer
+
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

+from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
+
+
+def to_bytes_big(value: int, size: int) -> bytes:
+    return value.to_bytes(size, byteorder="big")
+
+
 logger = init_logger(__name__)
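A quick round trip for the two helpers just added (definitions repeated so the snippet runs standalone); they produce the fixed-width big-endian count and length fields used by the chunk layout in enqueue()/dequeue() further down:

import functools

from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


def to_bytes_big(value: int, size: int) -> bytes:
    return value.to_bytes(size, byteorder="big")


# 2-byte field for the buffer count, 4-byte field for each buffer length.
assert to_bytes_big(3, 2) == b"\x00\x03"
assert from_bytes_big(to_bytes_big(10 * 1024 * 1024, 4)) == 10 * 1024 * 1024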
@@ -225,7 +236,7 @@ class MessageQueue:
         n_reader,  # number of all readers
         n_local_reader,  # number of local readers through shared memory
         local_reader_ranks: list[int] | None = None,
-        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunk_bytes: int = 1024 * 1024 * 24,  # 24MiB
         max_chunks: int = 10,
         connect_ip: str | None = None,
     ):
@@ -505,18 +516,41 @@ class MessageQueue:
     def enqueue(self, obj, timeout: float | None = None):
         """Write to message queue with optional timeout (in seconds)"""
         assert self._is_writer, "Only writers can enqueue"
-        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        all_buffers: list[SizedBuffer] = [b""]
+        total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size
+
+        def oob_callback(buf: PickleBuffer) -> bool:
+            raw_buf = buf.raw()
+            if len(raw_buf) < 1024 * 1024:
+                # In-line buffers smaller than 1MiB.
+                return True
+            all_buffers.append(raw_buf)
+            nonlocal total_bytes
+            total_bytes += len(raw_buf) + 4
+            return False
+
+        all_buffers[0] = pickle.dumps(
+            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
+        )
         if self.n_local_reader > 0:
-            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+            if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
                 with self.acquire_write(timeout) as buf:
                     buf[0] = 1  # overflow
-                self.local_socket.send(serialized_obj)
+                self.local_socket.send_multipart(all_buffers, copy=False)
             else:
                 with self.acquire_write(timeout) as buf:
                     buf[0] = 0  # not overflow
-                    buf[1 : len(serialized_obj) + 1] = serialized_obj
+                    offset = 3
+                    buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # oob buf count
+                    for buffer in all_buffers:
+                        buf_len = len(buffer)
+                        # prepend each buffer with 4 bytes containing its size.
+                        buf_offset = offset + 4
+                        buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
+                        buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

         if self.n_remote_reader > 0:
-            self.remote_socket.send(serialized_obj)
+            self.remote_socket.send_multipart(all_buffers, copy=False)

     def dequeue(
         self,
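For the overflow branch above, the message leaves shared memory and travels as one ZeroMQ multipart message: frame 0 is the main pickle stream, the remaining frames are the out-of-band buffers. The matching reader side is the MessageQueue.recv() change in the last hunk. A minimal sketch under illustrative assumptions (PAIR sockets over inproc and a NumPy payload; vLLM's actual local_socket/remote_socket wiring differs):

import pickle

import numpy as np
import zmq

ctx = zmq.Context.instance()
writer = ctx.socket(zmq.PAIR)
reader = ctx.socket(zmq.PAIR)
writer.bind("inproc://shm-overflow-demo")
reader.connect("inproc://shm-overflow-demo")

payload = {"step": 7, "logits": np.ones((1024, 1024), dtype=np.float32)}
oob: list[pickle.PickleBuffer] = []
main = pickle.dumps(payload, protocol=5, buffer_callback=oob.append)

# Writer: one multipart message, no extra copies of the large buffers.
writer.send_multipart([main, *(b.raw() for b in oob)], copy=False)

# Reader: with copy=False the frames come back as zmq.Frame objects whose
# .buffer views pickle can consume directly.
frames = reader.recv_multipart(copy=False)
restored = pickle.loads(frames[0].buffer, buffers=[f.buffer for f in frames[1:]])
assert np.array_equal(restored["logits"], payload["logits"])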
@@ -529,10 +563,15 @@ class MessageQueue:
             with self.acquire_read(timeout, cancel, indefinite) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
-                    # no need to know the size of serialized object
-                    # pickle format contains the size information internally
-                    # see https://docs.python.org/3/library/pickle.html
-                    obj = pickle.loads(buf[1:])
+                    offset = 3
+                    buf_count = from_bytes_big(buf[1:offset])
+                    all_buffers = []
+                    for i in range(buf_count):
+                        buf_offset = offset + 4
+                        buf_len = from_bytes_big(buf[offset:buf_offset])
+                        offset = buf_offset + buf_len
+                        all_buffers.append(buf[buf_offset:offset])
+                    obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
             if overflow:
                 obj = MessageQueue.recv(self.local_socket, timeout)
         elif self._is_remote_reader:
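Writer and reader above agree on a simple layout inside each shared-memory chunk: byte 0 is the overflow flag, bytes 1-2 hold the big-endian buffer count, and every buffer (main pickle stream first) is preceded by a 4-byte big-endian length. A self-contained sketch of that layout with made-up buffer contents:

main = b"main-pickle-stream"
oob = [b"A" * 1_500_000, b"B" * 2_000_000]  # stand-ins for large raw buffers
all_buffers = [main, *oob]

chunk = bytearray(1 + 2 + sum(4 + len(b) for b in all_buffers))
chunk[0] = 0  # not overflow
chunk[1:3] = len(all_buffers).to_bytes(2, "big")
offset = 3
for b in all_buffers:
    chunk[offset : offset + 4] = len(b).to_bytes(4, "big")
    offset += 4
    chunk[offset : offset + len(b)] = b
    offset += len(b)

# Reader side: zero-copy views over the same memory, in the spirit of what
# dequeue() does with the shared-memory buffer.
view = memoryview(chunk)
count = int.from_bytes(view[1:3], "big")
offset, parsed = 3, []
for _ in range(count):
    length = int.from_bytes(view[offset : offset + 4], "big")
    offset += 4
    parsed.append(view[offset : offset + length])
    offset += length

assert len(parsed) == 3 and bytes(parsed[0]) == main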
@@ -546,15 +585,14 @@ class MessageQueue:
         timeout_ms = None if timeout is None else int(timeout * 1000)
         if not socket.poll(timeout=timeout_ms):
             raise TimeoutError
-        recv = socket.recv(copy=False)
-        return pickle.loads(recv.buffer)
+        recv, *recv_oob = socket.recv_multipart(copy=False)
+        return pickle.loads(recv, buffers=recv_oob)

     def broadcast_object(self, obj=None):
         if self._is_writer:
             self.enqueue(obj)
             return obj
-        else:
-            return self.dequeue()
+        return self.dequeue()

     @staticmethod
     def create_from_process_group(