diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 114516ff07a1f..31c6084c9b507 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle +import threading import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL from_bytes_big = functools.partial(int.from_bytes, byteorder="big") +# Memory fence for cross-process shared memory visibility. +# Required for correct producer-consumer synchronization when using +# shared memory without locks. +_memory_fence_lock = threading.Lock() + + +def memory_fence(): + """ + Full memory barrier for shared memory synchronization. + + Ensures all prior memory writes are visible to other processes before + any subsequent reads. This is critical for lock-free producer-consumer + patterns using shared memory. + + Implementation acquires and immediately releases a lock. Python's + threading.Lock provides sequentially consistent memory barrier semantics + across all major platforms (POSIX, Windows). This is a lightweight + operation (~20ns) that guarantees: + - All stores before the barrier are visible to other threads/processes + - All loads after the barrier see the latest values + """ + # Lock acquire/release provides full memory barrier semantics. + # Using context manager ensures lock release even on exceptions. + with _memory_fence_lock: + pass + + def to_bytes_big(value: int, size: int) -> bytes: return value.to_bytes(size, byteorder="big") @@ -414,6 +442,10 @@ class MessageQueue: n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest read flags from readers. + # Without this, we may read stale flags from our CPU cache and + # spin indefinitely even though readers have completed. + memory_fence() read_count = sum(metadata_buffer[1:]) written_flag = metadata_buffer[0] if written_flag and read_count != self.buffer.n_reader: @@ -458,6 +490,10 @@ class MessageQueue: metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 + # Memory fence ensures the write is visible to readers on other cores + # before we proceed. Without this, readers may spin indefinitely + # waiting for a write that's stuck in our CPU's store buffer. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @@ -473,6 +509,10 @@ class MessageQueue: n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest writes from the writer. + # Without this, we may read stale flags from our CPU cache + # and spin indefinitely even though writer has updated them. + memory_fence() read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -513,6 +553,10 @@ class MessageQueue: # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 + # Memory fence ensures the read flag is visible to the writer. + # Without this, writer may not see our read completion and + # could wait indefinitely for all readers to finish. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity()