mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:17:07 +08:00
[Perf] Exploit out-of-band buffers in shm_broadcast (#26961)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
4ffd6e8942
commit
ab81379ea6
@ -1,13 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
|
from pickle import PickleBuffer
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -33,8 +34,18 @@ from vllm.utils import (
|
|||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _typeshed import SizedBuffer
|
||||||
|
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||||
|
|
||||||
|
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
|
def to_bytes_big(value: int, size: int) -> bytes:
|
||||||
|
return value.to_bytes(size, byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -225,7 +236,7 @@ class MessageQueue:
|
|||||||
n_reader, # number of all readers
|
n_reader, # number of all readers
|
||||||
n_local_reader, # number of local readers through shared memory
|
n_local_reader, # number of local readers through shared memory
|
||||||
local_reader_ranks: list[int] | None = None,
|
local_reader_ranks: list[int] | None = None,
|
||||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB
|
||||||
max_chunks: int = 10,
|
max_chunks: int = 10,
|
||||||
connect_ip: str | None = None,
|
connect_ip: str | None = None,
|
||||||
):
|
):
|
||||||
@ -505,18 +516,41 @@ class MessageQueue:
|
|||||||
def enqueue(self, obj, timeout: float | None = None):
|
def enqueue(self, obj, timeout: float | None = None):
|
||||||
"""Write to message queue with optional timeout (in seconds)"""
|
"""Write to message queue with optional timeout (in seconds)"""
|
||||||
assert self._is_writer, "Only writers can enqueue"
|
assert self._is_writer, "Only writers can enqueue"
|
||||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
all_buffers: list[SizedBuffer] = [b""]
|
||||||
|
total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size
|
||||||
|
|
||||||
|
def oob_callback(buf: PickleBuffer) -> bool:
|
||||||
|
raw_buf = buf.raw()
|
||||||
|
if len(raw_buf) < 1024 * 1024:
|
||||||
|
# In-line buffers smaller than 1MiB.
|
||||||
|
return True
|
||||||
|
all_buffers.append(raw_buf)
|
||||||
|
nonlocal total_bytes
|
||||||
|
total_bytes += len(raw_buf) + 4
|
||||||
|
return False
|
||||||
|
|
||||||
|
all_buffers[0] = pickle.dumps(
|
||||||
|
obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
|
||||||
|
)
|
||||||
if self.n_local_reader > 0:
|
if self.n_local_reader > 0:
|
||||||
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
|
||||||
with self.acquire_write(timeout) as buf:
|
with self.acquire_write(timeout) as buf:
|
||||||
buf[0] = 1 # overflow
|
buf[0] = 1 # overflow
|
||||||
self.local_socket.send(serialized_obj)
|
self.local_socket.send_multipart(all_buffers, copy=False)
|
||||||
else:
|
else:
|
||||||
with self.acquire_write(timeout) as buf:
|
with self.acquire_write(timeout) as buf:
|
||||||
buf[0] = 0 # not overflow
|
buf[0] = 0 # not overflow
|
||||||
buf[1 : len(serialized_obj) + 1] = serialized_obj
|
offset = 3
|
||||||
|
buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count
|
||||||
|
for buffer in all_buffers:
|
||||||
|
buf_len = len(buffer)
|
||||||
|
# prepend each buffer with 4 bytes containing its size.
|
||||||
|
buf_offset = offset + 4
|
||||||
|
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
|
||||||
|
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
|
||||||
|
|
||||||
if self.n_remote_reader > 0:
|
if self.n_remote_reader > 0:
|
||||||
self.remote_socket.send(serialized_obj)
|
self.remote_socket.send_multipart(all_buffers, copy=False)
|
||||||
|
|
||||||
def dequeue(
|
def dequeue(
|
||||||
self,
|
self,
|
||||||
@ -529,10 +563,15 @@ class MessageQueue:
|
|||||||
with self.acquire_read(timeout, cancel, indefinite) as buf:
|
with self.acquire_read(timeout, cancel, indefinite) as buf:
|
||||||
overflow = buf[0] == 1
|
overflow = buf[0] == 1
|
||||||
if not overflow:
|
if not overflow:
|
||||||
# no need to know the size of serialized object
|
offset = 3
|
||||||
# pickle format contains the size information internally
|
buf_count = from_bytes_big(buf[1:offset])
|
||||||
# see https://docs.python.org/3/library/pickle.html
|
all_buffers = []
|
||||||
obj = pickle.loads(buf[1:])
|
for i in range(buf_count):
|
||||||
|
buf_offset = offset + 4
|
||||||
|
buf_len = from_bytes_big(buf[offset:buf_offset])
|
||||||
|
offset = buf_offset + buf_len
|
||||||
|
all_buffers.append(buf[buf_offset:offset])
|
||||||
|
obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
|
||||||
if overflow:
|
if overflow:
|
||||||
obj = MessageQueue.recv(self.local_socket, timeout)
|
obj = MessageQueue.recv(self.local_socket, timeout)
|
||||||
elif self._is_remote_reader:
|
elif self._is_remote_reader:
|
||||||
@ -546,15 +585,14 @@ class MessageQueue:
|
|||||||
timeout_ms = None if timeout is None else int(timeout * 1000)
|
timeout_ms = None if timeout is None else int(timeout * 1000)
|
||||||
if not socket.poll(timeout=timeout_ms):
|
if not socket.poll(timeout=timeout_ms):
|
||||||
raise TimeoutError
|
raise TimeoutError
|
||||||
recv = socket.recv(copy=False)
|
recv, *recv_oob = socket.recv_multipart(copy=False)
|
||||||
return pickle.loads(recv.buffer)
|
return pickle.loads(recv, buffers=recv_oob)
|
||||||
|
|
||||||
def broadcast_object(self, obj=None):
|
def broadcast_object(self, obj=None):
|
||||||
if self._is_writer:
|
if self._is_writer:
|
||||||
self.enqueue(obj)
|
self.enqueue(obj)
|
||||||
return obj
|
return obj
|
||||||
else:
|
return self.dequeue()
|
||||||
return self.dequeue()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_from_process_group(
|
def create_from_process_group(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user