[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-04-25 23:00:07 -07:00 committed by GitHub
parent 53e8cf53a4
commit b07bf83c7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 76 additions and 11 deletions

View File

@@ -32,6 +32,7 @@ class MyType:
large_f_contig_tensor: torch.Tensor large_f_contig_tensor: torch.Tensor
small_non_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor
large_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor
empty_tensor: torch.Tensor
def test_encode_decode(): def test_encode_decode():
@@ -58,6 +59,7 @@ def test_encode_decode():
large_f_contig_tensor=torch.rand(1024, 4).t(), large_f_contig_tensor=torch.rand(1024, 4).t(),
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
empty_tensor=torch.empty(0),
) )
encoder = MsgpackEncoder(size_threshold=256) encoder = MsgpackEncoder(size_threshold=256)
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2.small_non_contig_tensor) obj2.small_non_contig_tensor)
assert torch.equal(obj1.large_non_contig_tensor, assert torch.equal(obj1.large_non_contig_tensor,
obj2.large_non_contig_tensor) obj2.large_non_contig_tensor)
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)

View File

@@ -5,6 +5,7 @@ import signal
import sys import sys
import threading import threading
import time import time
from collections import deque
from concurrent.futures import Future from concurrent.futures import Future
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
# Msgpack serialization encoding. # Msgpack serialization encoding.
encoder = MsgpackEncoder() encoder = MsgpackEncoder()
# Reuse send buffer. # Send buffers to reuse.
buffer = bytearray() reuse_buffers: list[bytearray] = []
# Keep references to outputs and buffers until zmq is finished
# with them (outputs may contain tensors/np arrays whose
# backing buffers were extracted for zero-copy send).
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
# We must set linger to ensure the ENGINE_CORE_DEAD # We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket. # message is sent prior to closing the socket.
@@ -541,8 +546,22 @@
break break
assert not isinstance(outputs, bytes) assert not isinstance(outputs, bytes)
outputs.engine_index = engine_index outputs.engine_index = engine_index
# Reclaim buffers that zmq is finished with.
while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer) buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False) tracker = socket.send_multipart(buffers,
copy=False,
track=True)
if not tracker.done:
ref = outputs if len(buffers) > 1 else None
pending.appendleft((tracker, ref, buffer))
elif len(reuse_buffers) < 2:
# Keep at most 2 buffers to reuse.
reuse_buffers.append(buffer)
class DPEngineCoreProc(EngineCoreProc): class DPEngineCoreProc(EngineCoreProc):

View File

@@ -1,9 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import contextlib
import queue import queue
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Awaitable, Sequence from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
self._wait_for_engine_startup() self._wait_for_engine_startup()
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
# Request objects which may contain pytorch-allocated tensors
# that we need to keep references to until zmq is done with the
# underlying data.
self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
success = True success = True
finally: finally:
if not success: if not success:
@@ -459,6 +467,14 @@
if self.resources.engine_dead: if self.resources.engine_dead:
raise EngineDeadError() raise EngineDeadError()
def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
if not tracker.done:
self.pending_messages.appendleft((tracker, msg))
def free_pending_messages(self):
while self.pending_messages and self.pending_messages[-1][0].done:
self.pending_messages.pop()
def _process_utility_output(output: UtilityOutput, def _process_utility_output(output: UtilityOutput,
utility_results: dict[int, AnyFuture]): utility_results: dict[int, AnyFuture]):
@@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
def _send_input(self, request_type: EngineCoreRequestType, request: Any): def _send_input(self, request_type: EngineCoreRequestType, request: Any):
self.ensure_alive() self.ensure_alive()
self.free_pending_messages()
# (Identity, RequestType, SerializedRequest) # (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value, msg = (self.core_engine.identity, request_type.value,
*self.encoder.encode(request)) *self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
if len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
self.input_socket.send_multipart(msg, copy=False)
return
tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
self.add_pending_message(tracker, request)
def call_utility(self, method: str, *args) -> Any: def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
@@ -698,19 +722,38 @@ class AsyncMPClient(MPClient):
def _send_input(self, def _send_input(self,
request_type: EngineCoreRequestType, request_type: EngineCoreRequestType,
request: Any, request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]: engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
self.ensure_alive() self.ensure_alive()
if engine is None: if engine is None:
engine = self.core_engine engine = self.core_engine
message = (request_type.value, *self.encoder.encode(request)) message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine) return self._send_input_message(message, engine, request)
def _send_input_message(self, message: tuple[bytestr, ...], def _send_input_message(self, message: tuple[bytestr,
engine: CoreEngine) -> Awaitable[None]: ...], engine: CoreEngine,
objects: Any) -> Awaitable[Any]:
"""
objects is a reference to retain until zmq is finished with the
buffers, in case they were extracted from tensors in the request.
"""
self.ensure_alive() self.ensure_alive()
message = (engine.identity, ) + message self.free_pending_messages()
return self.input_socket.send_multipart(message, copy=False)
msg = (engine.identity, ) + message
if not objects or len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
return self.input_socket.send_multipart(msg, copy=False)
future: asyncio.Future[zmq.MessageTracker]
future = self.input_socket.send_multipart(msg, copy=False, track=True)
def add_pending(f: asyncio.Future[zmq.MessageTracker]):
with contextlib.suppress(BaseException):
self.add_pending_message(f.result(), objects)
future.add_done_callback(add_pending)
return future
async def call_utility_async(self, method: str, *args) -> Any: async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method, return await self._call_utility_async(method,
@@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
self.utility_results[call_id] = future self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args))) (call_id, method, args)))
await self._send_input_message(message, engine) await self._send_input_message(message, engine, args)
self._ensure_output_queue_task() self._ensure_output_queue_task()
return await future return await future