diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py new file mode 100644 index 0000000000000..0fc3b074533da --- /dev/null +++ b/tests/v1/test_serial_utils.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections import UserDict +from dataclasses import dataclass + +import numpy as np +import torch + +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + + +class UnrecognizedType(UserDict): + + def __init__(self, an_int: int): + super().__init__() + self.an_int = an_int + + +@dataclass +class MyType: + tensor1: torch.Tensor + a_string: str + list_of_tensors: list[torch.Tensor] + numpy_array: np.ndarray + unrecognized: UnrecognizedType + + +def test_encode_decode(): + """Test encode/decode loop with zero-copy tensors.""" + + obj = MyType( + tensor1=torch.randint(low=0, + high=100, + size=(1024, ), + dtype=torch.int32), + a_string="hello", + list_of_tensors=[ + torch.rand((1, 10), dtype=torch.float32), + torch.rand((3, 5, 4000), dtype=torch.float64), + torch.tensor(1984), # test scalar too + ], + numpy_array=np.arange(512), + unrecognized=UnrecognizedType(33), + ) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyType) + + encoded = encoder.encode(obj) + + # There should be the main buffer + 2 large tensor buffers + # + 1 large numpy array. "large" is >= 256 bytes. + # The two small tensors are encoded inline. 
+ assert len(encoded) == 4 + + decoded: MyType = decoder.decode(encoded) + + assert_equal(decoded, obj) + + # Test encode_into case + + preallocated = bytearray() + + encoded2 = encoder.encode_into(obj, preallocated) + + assert len(encoded2) == 4 + assert encoded2[0] is preallocated + + decoded2: MyType = decoder.decode(encoded2) + + assert_equal(decoded2, obj) + + +def assert_equal(obj1: MyType, obj2: MyType): + assert torch.equal(obj1.tensor1, obj2.tensor1) + assert obj1.a_string == obj2.a_string + assert all( + torch.equal(a, b) + for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) + assert np.array_equal(obj1.numpy_array, obj2.numpy_array) + assert obj1.unrecognized.an_int == obj2.unrecognized.an_int diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 077d499889623..b8c2bebbc5ecb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore): while True: # (RequestType, RequestData) - type_frame, data_frame = socket.recv_multipart(copy=False) + type_frame, *data_frames = socket.recv_multipart(copy=False) request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. decoder = add_request_decoder if ( request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frame.buffer) + request = decoder.decode(data_frames) # Push to input queue for core busy loop. 
self.input_queue.put_nowait((request_type, request)) @@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore): while True: outputs = self.output_queue.get() outputs.engine_index = engine_index - encoder.encode_into(outputs, buffer) - socket.send(buffer, copy=False) + buffers = encoder.encode_into(outputs, buffer) + socket.send_multipart(buffers, copy=False) ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 2e5f9021f1009..a96ebc7edb538 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -505,8 +505,8 @@ class SyncMPClient(MPClient): # shutdown signal, exit thread. 
break - frame = out_socket.recv(copy=False) - outputs = decoder.decode(frame.buffer) + frames = out_socket.recv_multipart(copy=False) + outputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) @@ -529,7 +529,7 @@ class SyncMPClient(MPClient): def _send_input(self, request_type: EngineCoreRequestType, request: Any): # (Identity, RequestType, SerializedRequest) msg = (self.core_engine.identity, request_type.value, - self.encoder.encode(request)) + *self.encoder.encode(request)) self.input_socket.send_multipart(msg, copy=False) def call_utility(self, method: str, *args) -> Any: @@ -633,8 +633,8 @@ class AsyncMPClient(MPClient): async def process_outputs_socket(): while True: - (frame, ) = await output_socket.recv_multipart(copy=False) - outputs: EngineCoreOutputs = decoder.decode(frame.buffer) + frames = await output_socket.recv_multipart(copy=False) + outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) @@ -666,12 +666,12 @@ class AsyncMPClient(MPClient): if engine is None: engine = self.core_engine - message = (request_type.value, self.encoder.encode(request)) + message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine) - def _send_input_message(self, message: tuple[bytes, bytes], + def _send_input_message(self, message: tuple[bytestr, ...], engine: CoreEngine) -> Awaitable[None]: - message = (engine.identity, ) + message # type: ignore[assignment] + message = (engine.identity, ) + message return self.input_socket.send_multipart(message, copy=False) async def call_utility_async(self, method: str, *args) -> Any: @@ -684,8 +684,8 @@ class AsyncMPClient(MPClient): call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, - 
self.encoder.encode((call_id, method, args))) + message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( + (call_id, method, args))) await self._send_input_message(message, engine) self._ensure_output_queue_task() return await future @@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient): # Control message used for triggering dp idle mode loop. self.start_dp_msg = (EngineCoreRequestType.START_DP.value, - self.encoder.encode(None)) + *self.encoder.encode(None)) self.num_engines_running = 0 self.reqs_in_flight: dict[str, CoreEngine] = {} @@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient): # tokenized. request.prompt = None - msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request)) + msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request)) chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 146d7d747f1a4..99b352fdef80a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,61 +1,140 @@ # SPDX-License-Identifier: Apache-2.0 import pickle +from collections.abc import Sequence +from inspect import isclass from types import FunctionType -from typing import Any, Optional +from typing import Any, Optional, Union import cloudpickle +import numpy as np import torch +import zmq from msgspec import msgpack -CUSTOM_TYPE_TENSOR = 1 -CUSTOM_TYPE_PICKLE = 2 -CUSTOM_TYPE_CLOUDPICKLE = 3 +CUSTOM_TYPE_PICKLE = 1 +CUSTOM_TYPE_CLOUDPICKLE = 2 + +# TODO calibrate this size +INLINE_BUF_SIZE_THRESHOLD = 256 + +bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] class MsgpackEncoder: - """Encoder with custom torch tensor serialization.""" + """Encoder with custom torch tensor and numpy array serialization. + + Note that unlike vanilla `msgspec` Encoders, this interface is generally + not thread-safe when encoding tensors / numpy arrays. 
+ """ def __init__(self): - self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) + self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) + # This is used as a local stash of buffers that we can then access from + # our custom `msgspec` hook, `enc_hook`. We don't have a way to + # pass custom data to the hook otherwise. + self.aux_buffers: Optional[list[bytestr]] = None - def encode(self, obj: Any) -> bytes: - return self.encoder.encode(obj) + def encode(self, obj: Any) -> Sequence[bytestr]: + try: + self.aux_buffers = bufs = [b''] + bufs[0] = self.encoder.encode(obj) + # This `bufs` list allows us to collect direct pointers to backing + # buffers of tensors and np arrays, and return them along with the + # top-level encoded buffer instead of copying their data into the + # new buffer. + return bufs + finally: + self.aux_buffers = None - def encode_into(self, obj: Any, buf: bytearray) -> None: - self.encoder.encode_into(obj, buf) + def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: + try: + self.aux_buffers = [buf] + bufs = self.aux_buffers + self.encoder.encode_into(obj, buf) + return bufs + finally: + self.aux_buffers = None + + def enc_hook(self, obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + return self._encode_ndarray(obj.numpy()) + + # Fall back to pickle for object or void kind ndarrays. + if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): + return self._encode_ndarray(obj) + + if isinstance(obj, FunctionType): + # `pickle` is generally faster than cloudpickle, but can have + # problems serializing methods. 
+ return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) + + return msgpack.Ext(CUSTOM_TYPE_PICKLE, + pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + + def _encode_ndarray( + self, obj: np.ndarray + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD: + # Encode small arrays and scalars inline. + data = obj.data + else: + # Otherwise encode index of backing buffer. + obj = np.ascontiguousarray(obj) + data = len(self.aux_buffers) + self.aux_buffers.append(obj.data) + # We serialize the ndarray as a tuple of native types. + # The data is either inlined if small, or an index into a list of + # backing buffers that we've stashed in `aux_buffers`. + return obj.dtype.str, obj.shape, data class MsgpackDecoder: - """Decoder with custom torch tensor serialization.""" + """Decoder with custom torch tensor and numpy array serialization. + + Note that unlike vanilla `msgspec` Decoders, this interface is generally + not thread-safe when decoding tensors / numpy arrays. + """ def __init__(self, t: Optional[Any] = None): args = () if t is None else (t, ) - self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook) + self.decoder = msgpack.Decoder(*args, + ext_hook=self.ext_hook, + dec_hook=self.dec_hook) + self.aux_buffers: Sequence[bytestr] = () - def decode(self, obj: Any): - return self.decoder.decode(obj) + def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: + if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): + # TODO - This check can become `isinstance(bufs, bytestr)` + # as of Python 3.10. + return self.decoder.decode(bufs) + self.aux_buffers = bufs + try: + return self.decoder.decode(bufs[0]) + finally: + self.aux_buffers = () -def custom_enc_hook(obj: Any) -> Any: - if isinstance(obj, torch.Tensor): - # NOTE(rob): it is fastest to use numpy + pickle - # when serializing torch tensors. 
- # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 - return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) + def dec_hook(self, t: type, obj: Any) -> Any: + # Given native types in `obj`, convert to type `t`. + if isclass(t): + if issubclass(t, np.ndarray): + return self._decode_ndarray(obj) + if issubclass(t, torch.Tensor): + return torch.from_numpy(self._decode_ndarray(obj)) + return obj - if isinstance(obj, FunctionType): - return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) + def _decode_ndarray(self, arr: Any) -> np.ndarray: + dtype, shape, data = arr + buffer = self.aux_buffers[data] if isinstance(data, int) else data + return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) - return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) + def ext_hook(self, code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_PICKLE: + return pickle.loads(data) + if code == CUSTOM_TYPE_CLOUDPICKLE: + return cloudpickle.loads(data) - -def custom_ext_hook(code: int, data: memoryview) -> Any: - if code == CUSTOM_TYPE_TENSOR: - return torch.from_numpy(pickle.loads(data)) - if code == CUSTOM_TYPE_PICKLE: - return pickle.loads(data) - if code == CUSTOM_TYPE_CLOUDPICKLE: - return cloudpickle.loads(data) - - raise NotImplementedError(f"Extension type code {code} is not supported") + raise NotImplementedError( + f"Extension type code {code} is not supported")