mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:45:29 +08:00
Signed-off-by: Kero Liang <kerorek@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: donglu <donglu@cohere.com> Co-authored-by: Roger Wang <hey@rogerw.io>
244 lines
8.6 KiB
Python
244 lines
8.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import traceback
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from vllm.distributed.device_communicators.shm_object_storage import (
|
|
SingleWriterShmRingBuffer,
|
|
)
|
|
|
|
|
|
class TestSingleWriterShmRingBuffer(unittest.TestCase):
    """Test suite for the ring buffer implementation"""

    def setUp(self):
        """Set up test fixtures"""
        self.buffer_size = 4096
        # Each test stores the ring it creates here so tearDown can drop it.
        self.ring_buffer = None

    def tearDown(self):
        """Clean up after tests"""
        # Compare against None explicitly: the decision to release the ring
        # must not depend on how the buffer object evaluates as a boolean.
        if self.ring_buffer is not None:
            del self.ring_buffer

    def test_buffer_opening(self):
        """Test opening an existing buffer"""
        # First create a buffer
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        # Then open it with another instance built from the writer's handle
        reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())

        self.assertFalse(reader_buffer.is_writer)
        self.assertEqual(
            reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name
        )

    def test_buffer_access(self):
        """Test accessing allocated buffers"""
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        size = 100
        address, monotonic_id = self.ring_buffer.allocate_buf(size)

        # Write some test data
        test_data = b"Hello, World!" * 7  # 91 bytes
        with self.ring_buffer.access_buf(address) as (data_buf, _metadata):
            data_buf[0 : len(test_data)] = test_data

        # Read it back through a second access of the same address
        with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
            read_data = bytes(data_buf2[0 : len(test_data)])
            read_id = metadata2[0]

        self.assertEqual(read_data, test_data)
        # The first metadata entry must match the id returned by allocate_buf
        self.assertEqual(read_id, monotonic_id)

    def test_memory_error_on_full_buffer(self):
        """Test that MemoryError is raised when buffer is full"""
        small_buffer_size = 200
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=small_buffer_size, create=True
        )

        # Fill up the buffer
        self.ring_buffer.allocate_buf(100)
        # Original note: 196 bytes used in total, which implies 8 bytes of
        # per-allocation metadata -- presumably MD_SIZE; confirm if it changes.
        self.ring_buffer.allocate_buf(80)

        # This should fail
        with self.assertRaises(MemoryError):
            self.ring_buffer.allocate_buf(1)  # Would exceed buffer capacity

    def test_allocation_and_free(self):
        """Test allocation and freeing of buffers"""
        small_buffer_size = 200
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=small_buffer_size, create=True
        )

        size = 80
        # Write some data
        test_data = b"Repeated test data"
        for i in range(5):
            address, _ = self.ring_buffer.allocate_buf(size)
            with self.ring_buffer.access_buf(address) as (data_buf, _metadata):
                data_buf[0:4] = (0).to_bytes(4, "little")  # 0 for not in-use
                data_buf[4 : len(test_data) + 4] = test_data
            print(self.ring_buffer.metadata)

            # The predicate accepts everything, so the allocation made in this
            # iteration is expected to be reported as freed.
            freed_ids = self.ring_buffer.free_buf(lambda *args: True)
            print(f" Freed IDs: {freed_ids}")
            # Fail with a readable message instead of an IndexError when
            # nothing was freed.
            self.assertGreater(len(freed_ids), 0, "free_buf released nothing")
            self.assertEqual(freed_ids[0], i)

    def test_clear_buffer(self):
        """Test clearing the buffer"""
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        # Allocate some buffers
        for _ in range(3):
            self.ring_buffer.allocate_buf(100)

        # Clear the buffer
        self.ring_buffer.clear()

        # Check that metadata is empty and IDs reset
        self.assertEqual(len(self.ring_buffer.metadata), 0)
        self.assertEqual(self.ring_buffer.monotonic_id_start, 0)
        self.assertEqual(self.ring_buffer.monotonic_id_end, 0)
        self.assertEqual(self.ring_buffer.data_buffer_start, 0)
        self.assertEqual(self.ring_buffer.data_buffer_end, 0)

    def test_allocation_cycles(self):
        """Test repeated allocate/free cycles against a shadow bitmap.

        Every address range handed out by the ring is mirrored in a local
        bitmap; the helpers assert that a newly allocated range was entirely
        unmarked and that a freed range was entirely marked.
        """
        buffer_size = 100
        ring = SingleWriterShmRingBuffer(data_buffer_size=buffer_size, create=True)
        # Register the ring like the other tests do so tearDown handles it too.
        self.ring_buffer = ring

        # tracking allocations for assertions
        allocated_bitmap = np.zeros(
            (buffer_size,), dtype=np.bool_
        )  # addr -> is_allocated
        allocation_map = {}  # monotonic_id -> (addr, size)

        def count_allocated(bitmap) -> int:
            # Number of True cells in the given bitmap slice.
            return np.sum(bitmap).item()

        def is_free_fn(*_args) -> bool:
            # Treat every buffer as releasable, whatever free_buf passes in.
            return True

        def mark_allocated_with_assertion(monotonic_id, addr, size):
            # Map the returned address into the bitmap's index range.
            addr = addr % buffer_size
            # The range must not overlap any allocation that is still live.
            self.assertEqual(count_allocated(allocated_bitmap[addr : addr + size]), 0)

            allocated_bitmap[addr : addr + size] = True
            allocation_map[monotonic_id] = (addr, size)

        def mark_freed_with_assertion(monotonic_id):
            # Only ids this test recorded may be reported as freed.
            self.assertIn(monotonic_id, allocation_map)

            addr, size = allocation_map.pop(monotonic_id)
            addr = addr % buffer_size
            # The whole range must still be marked before it is released.
            self.assertEqual(
                count_allocated(allocated_bitmap[addr : addr + size]), size
            )

            allocated_bitmap[addr : addr + size] = False

        def ring_free(free_size=None):
            # Release space in the ring and mirror every freed id locally.
            for freed_id in ring.free_buf(is_free_fn, free_size):
                mark_freed_with_assertion(freed_id)

        def ring_allocate(allocate_size):
            # The bitmap tracks payload plus the per-allocation metadata.
            allocate_size_with_md = allocate_size + ring.MD_SIZE
            try:
                addr, monotonic_id = ring.allocate_buf(allocate_size)
            except MemoryError:
                # free 2x size for enough space if wrapping happened
                ring_free(allocate_size_with_md * 2)
                # retry allocating; a second MemoryError fails the test
                addr, monotonic_id = ring.allocate_buf(allocate_size)
            # Bookkeeping stays outside the try so only allocate_buf's
            # MemoryError triggers the free-and-retry path.
            mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md)

        # 1. allocation & free cycles
        for _ in range(33):
            # each allocation consumes 2 + ring.MD_SIZE bytes
            # (10 bytes per the original note, i.e. assuming MD_SIZE == 8)
            ring_allocate(2)

        # 2. free all allocations
        ring_free()

        # 3. try allocate the largest possible buffer
        ring_allocate(buffer_size - ring.MD_SIZE)
|
|
|
|
|
|
def main():
    """Main function demonstrating usage and running tests

    Runs the unittest suite of this module without exiting the interpreter,
    then performs a manual demo: create a writer ring buffer, attach a second
    instance through the writer's handle, write three short messages and read
    them back through the second instance. Any exception raised by the demo
    is printed together with its traceback instead of propagating.
    """
    print("=== SingleWriterShmRingBuffer Test Suite ===\n")

    # Run unit tests
    print("Running unit tests...")
    # exit=False keeps the process alive so the manual demo below can run.
    unittest.main(argv=[""], exit=False, verbosity=2)

    print("\n" + "=" * 50)
    print("=== Manual Demo ===\n")

    # Manual demonstration
    try:
        print("Creating ring buffer...")
        writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
        reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())

        print(f"Buffer created with name: {writer_buffer.shared_memory.name}")

        # Allocate some buffers
        print("\nAllocating buffers...")
        address_array = []
        for i in range(3):
            size = 100 + i * 50
            try:
                # Ask the ring to release whatever it can before allocating.
                writer_buffer.free_buf(lambda *args: True)
                address, monotonic_id = writer_buffer.allocate_buf(size)
                address_array.append((address, size, monotonic_id))

                # Write some test data
                with writer_buffer.access_buf(address) as (data_buf, _metadata):
                    test_message = f"Test message {i}".encode()
                    data_buf[0 : len(test_message)] = test_message

            except MemoryError as e:
                print(f" Failed to allocate {size} bytes: {e}")

        print("\nBuffer state:")
        print(f" Data buffer start: {writer_buffer.data_buffer_start}")
        print(f" Data buffer end: {writer_buffer.data_buffer_end}")
        print(f" Monotonic ID start: {writer_buffer.monotonic_id_start}")
        print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}")
        print(f" Metadata entries: {len(writer_buffer.metadata)}")

        # Try to read back the data
        print("\nReading back data...")
        for address, size, monotonic_id in address_array:
            with reader_buffer.access_buf(address) as (data_buf, _metadata):
                data_bytes = bytes(data_buf[0:size])
                # Only the first few bytes were written above. Cut at the
                # first NUL byte so the remainder of the slice is neither
                # decoded nor printed (assumes untouched bytes read back as
                # zeros -- if no NUL is present the full slice is decoded).
                message = data_bytes.partition(b"\x00")[0].decode()
                print(f" ID {monotonic_id}: '{message}'")

    except Exception as e:
        # Top-level demo boundary: report the failure, do not crash.
        print(f"Demo error: {e}")
        traceback.print_exc()

    print("\n=== Demo Complete ===")
|
|
|
|
|
|
# Script entry point: run the test suite and the manual demo only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|