[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)
Signed-off-by: Nick Hill <nhill@redhat.com>
Parent: 53e8cf53a4
Commit: b07bf83c7d
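Note on the race being fixed: with copy=False, pyzmq hands the message's backing buffers to its I/O thread, and send_multipart can return before the bytes have actually been transmitted. If the sender then recycles the serialization bytearray or drops the last reference to a tensor whose storage backs one of the frames, the peer can receive corrupted data. The fix threads pyzmq's track=True option through every zero-copy send and holds the buffers until the returned zmq.MessageTracker reports done. A minimal standalone sketch of that pyzmq mechanism, independent of vLLM (the inproc address and payload are illustrative only):

import zmq

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)
pull = ctx.socket(zmq.PULL)
push.bind("inproc://tracker-demo")
pull.connect("inproc://tracker-demo")

payload = bytearray(b"x" * 1_000_000)

# copy=False sends the buffer zero-copy; track=True makes pyzmq return a
# MessageTracker instead of None so we can tell when the buffer is released.
tracker = push.send_multipart([b"header", payload], copy=False, track=True)

frames = pull.recv_multipart()  # peer consumes (and copies) the message

# Only once the tracker reports done is it safe to mutate or reuse `payload`;
# before that, the zmq I/O thread may still be reading from it.
if not tracker.done:
    tracker.wait()
payload[:5] = b"reuse"

ctx.destroy()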
@@ -32,6 +32,7 @@ class MyType:
     large_f_contig_tensor: torch.Tensor
     small_non_contig_tensor: torch.Tensor
     large_non_contig_tensor: torch.Tensor
+    empty_tensor: torch.Tensor
 
 
 def test_encode_decode():
@@ -58,6 +59,7 @@ def test_encode_decode():
         large_f_contig_tensor=torch.rand(1024, 4).t(),
         small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
         large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
+        empty_tensor=torch.empty(0),
     )
 
     encoder = MsgpackEncoder(size_threshold=256)
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
                        obj2.small_non_contig_tensor)
     assert torch.equal(obj1.large_non_contig_tensor,
                        obj2.large_non_contig_tensor)
+    assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
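The test type already covers contiguous, Fortran-contiguous, and strided-view tensors; this change adds a zero-element tensor, another edge case for a serializer that ships tensor storage out of band. A small illustration of why those layouts are the interesting cases (hypothetical handling, not vLLM's MsgpackEncoder):

import torch

t = torch.rand(1024, 512)
cases = {
    "f_contig": torch.rand(1024, 4).t(),   # transpose view, Fortran order
    "non_contig": t[:, 10:20],             # strided slice, not contiguous
    "empty": torch.empty(0),               # zero elements, zero-length buffer
}
for name, x in cases.items():
    # Exporting a tensor's raw storage for zero-copy transport assumes a
    # contiguous buffer; views like these must be materialized first.
    flat = x.contiguous().view(-1).numpy()  # hypothetical pre-serialization step
    print(f"{name}: contiguous={x.is_contiguous()} bytes={flat.nbytes}")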
@@ -5,6 +5,7 @@ import signal
 import sys
 import threading
 import time
+from collections import deque
 from concurrent.futures import Future
 from inspect import isclass, signature
 from logging import DEBUG
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
 
         # Msgpack serialization encoding.
         encoder = MsgpackEncoder()
-        # Reuse send buffer.
-        buffer = bytearray()
+        # Send buffers to reuse.
+        reuse_buffers: list[bytearray] = []
+        # Keep references to outputs and buffers until zmq is finished
+        # with them (outputs may contain tensors/np arrays whose
+        # backing buffers were extracted for zero-copy send).
+        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
 
         # We must set linger to ensure the ENGINE_CORE_DEAD
         # message is sent prior to closing the socket.
@@ -541,8 +546,22 @@ class EngineCoreProc(EngineCore):
                     break
                 assert not isinstance(outputs, bytes)
                 outputs.engine_index = engine_index
+
+                # Reclaim buffers that zmq is finished with.
+                while pending and pending[-1][0].done:
+                    reuse_buffers.append(pending.pop()[2])
+
+                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                 buffers = encoder.encode_into(outputs, buffer)
-                socket.send_multipart(buffers, copy=False)
+                tracker = socket.send_multipart(buffers,
+                                                copy=False,
+                                                track=True)
+                if not tracker.done:
+                    ref = outputs if len(buffers) > 1 else None
+                    pending.appendleft((tracker, ref, buffer))
+                elif len(reuse_buffers) < 2:
+                    # Keep at most 2 buffers to reuse.
+                    reuse_buffers.append(buffer)
 
 
 class DPEngineCoreProc(EngineCoreProc):
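Taken together, the EngineCoreProc changes replace the single reused send buffer with a small pool plus a deque of in-flight sends: buffers (and the outputs whose tensors back extra frames) are only recycled or dropped once their MessageTracker reports done. A condensed sketch of that bookkeeping pattern with generic names (the encode callable and its contract are assumptions, not vLLM's API):

from collections import deque
from typing import Any, Callable, Iterable

import zmq


def tracked_send_loop(socket: zmq.Socket,
                      encode: Callable[[Any, bytearray], list],
                      outputs_iter: Iterable[Any]) -> None:
    """Sketch: encode each output into a recycled bytearray, send it zero-copy,
    and hold references until zmq has finished with the underlying memory."""
    reuse_buffers: list[bytearray] = []
    # (tracker, retained outputs, serialization buffer) for in-flight sends.
    pending: deque[tuple[zmq.MessageTracker, Any, bytearray]] = deque()

    for outputs in outputs_iter:
        # Reclaim buffers whose sends have completed (oldest first).
        while pending and pending[-1][0].done:
            reuse_buffers.append(pending.pop()[2])

        buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
        # Assumed contract: frames[0] is `buffer`; any extra frames are
        # zero-copy views over tensor storage owned by `outputs`.
        frames = encode(outputs, buffer)

        tracker = socket.send_multipart(frames, copy=False, track=True)
        if not tracker.done:
            # Keep `outputs` alive only if extra frames reference its tensors.
            ref = outputs if len(frames) > 1 else None
            pending.appendleft((tracker, ref, buffer))
        elif len(reuse_buffers) < 2:
            # Send already completed; cap the idle pool at two buffers.
            reuse_buffers.append(buffer)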
@@ -1,9 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
+import contextlib
 import queue
 import uuid
 import weakref
 from abc import ABC, abstractmethod
+from collections import deque
 from collections.abc import Awaitable, Sequence
 from concurrent.futures import Future
 from dataclasses import dataclass, field
@@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
             self._wait_for_engine_startup()
 
             self.utility_results: dict[int, AnyFuture] = {}
+
+            # Request objects which may contain pytorch-allocated tensors
+            # that we need to keep references to until zmq is done with the
+            # underlying data.
+            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
+
             success = True
         finally:
             if not success:
@@ -459,6 +467,14 @@ class MPClient(EngineCoreClient):
         if self.resources.engine_dead:
             raise EngineDeadError()
 
+    def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
+        if not tracker.done:
+            self.pending_messages.appendleft((tracker, msg))
+
+    def free_pending_messages(self):
+        while self.pending_messages and self.pending_messages[-1][0].done:
+            self.pending_messages.pop()
+
 
 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
 
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
         self.ensure_alive()
+        self.free_pending_messages()
         # (Identity, RequestType, SerializedRequest)
         msg = (self.core_engine.identity, request_type.value,
                *self.encoder.encode(request))
-        self.input_socket.send_multipart(msg, copy=False)
+
+        if len(msg) <= 3:
+            # No auxiliary buffers => no tensor backing buffers in request.
+            self.input_socket.send_multipart(msg, copy=False)
+            return
+
+        tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
+        self.add_pending_message(tracker, request)
 
     def call_utility(self, method: str, *args) -> Any:
         call_id = uuid.uuid1().int >> 64
@@ -698,19 +722,38 @@ class AsyncMPClient(MPClient):
     def _send_input(self,
                     request_type: EngineCoreRequestType,
                     request: Any,
-                    engine: Optional[CoreEngine] = None) -> Awaitable[None]:
+                    engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
         self.ensure_alive()
         if engine is None:
             engine = self.core_engine
 
         message = (request_type.value, *self.encoder.encode(request))
-        return self._send_input_message(message, engine)
+        return self._send_input_message(message, engine, request)
 
-    def _send_input_message(self, message: tuple[bytestr, ...],
-                            engine: CoreEngine) -> Awaitable[None]:
+    def _send_input_message(self, message: tuple[bytestr,
+                                                 ...], engine: CoreEngine,
+                            objects: Any) -> Awaitable[Any]:
+        """
+        objects is a reference to retain until zmq is finished with the
+        buffers, in case they were extracted from tensors in the request.
+        """
         self.ensure_alive()
-        message = (engine.identity, ) + message
-        return self.input_socket.send_multipart(message, copy=False)
+        self.free_pending_messages()
+
+        msg = (engine.identity, ) + message
+        if not objects or len(msg) <= 3:
+            # No auxiliary buffers => no tensor backing buffers in request.
+            return self.input_socket.send_multipart(msg, copy=False)
+
+        future: asyncio.Future[zmq.MessageTracker]
+        future = self.input_socket.send_multipart(msg, copy=False, track=True)
+
+        def add_pending(f: asyncio.Future[zmq.MessageTracker]):
+            with contextlib.suppress(BaseException):
+                self.add_pending_message(f.result(), objects)
+
+        future.add_done_callback(add_pending)
+        return future
 
     async def call_utility_async(self, method: str, *args) -> Any:
         return await self._call_utility_async(method,
@@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
         self.utility_results[call_id] = future
         message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
            (call_id, method, args)))
-        await self._send_input_message(message, engine)
+        await self._send_input_message(message, engine, args)
         self._ensure_output_queue_task()
         return await future
 
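On the client side the same idea is applied to request sends: MPClient now parks (tracker, request) pairs in pending_messages and trims them opportunistically before each send, so pytorch-allocated tensors embedded in a request stay referenced until zmq is done with their storage. A self-contained sketch of the asynchronous variant, where the asyncio socket returns a Future and the tracker is stashed from a done-callback (class and parameter names are generic, not vLLM's):

import asyncio
import contextlib
from collections import deque
from typing import Any

import zmq
import zmq.asyncio


class TrackedSender:
    """Sketch: keep request objects alive while zmq reads their buffers."""

    def __init__(self, socket: zmq.asyncio.Socket):
        self.socket = socket
        # Requests whose tensor-backed frames zmq may still be reading.
        self.pending: deque[tuple[zmq.MessageTracker, Any]] = deque()

    def free_pending(self) -> None:
        # Drop references once zmq reports the buffers are released.
        while self.pending and self.pending[-1][0].done:
            self.pending.pop()

    def send(self, frames: tuple, request: Any) -> asyncio.Future:
        self.free_pending()
        if len(frames) <= 1:
            # A single frame means the encoder produced no out-of-band tensor
            # buffers, so there is nothing to keep alive.
            return self.socket.send_multipart(frames, copy=False)

        # The asyncio socket returns a Future; with track=True its result is
        # the MessageTracker for this send.
        future = self.socket.send_multipart(frames, copy=False, track=True)

        def stash(f: asyncio.Future) -> None:
            # If the send failed or was cancelled there is nothing to retain.
            with contextlib.suppress(BaseException):
                self.pending.appendleft((f.result(), request))

        future.add_done_callback(stash)
        return future

The synchronous SyncMPClient path above does the same inline: the plain socket returns the tracker directly, and add_pending_message records it only if the send has not already completed.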