diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index b55018ae8ef03..d1271b210ad88 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -5,6 +5,7 @@ from typing import Optional import msgspec import numpy as np +import pytest import torch from vllm.multimodal.inputs import (MultiModalBatchedField, @@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) + + +@pytest.mark.parametrize("allow_pickle", [True, False]) +def test_dict_serialization(allow_pickle: bool): + """Test encoding and decoding of a generic Python object using pickle.""" + encoder = MsgpackEncoder(allow_pickle=allow_pickle) + decoder = MsgpackDecoder(allow_pickle=allow_pickle) + + # Create a sample Python object + obj = {"key": "value", "number": 42} + + # Encode the object + encoded = encoder.encode(obj) + + # Decode the object + decoded = decoder.decode(encoded) + + # Verify the decoded object matches the original + assert obj == decoded, "Decoded object does not match the original object." + + +@pytest.mark.parametrize("allow_pickle", [True, False]) +def test_tensor_serialization(allow_pickle: bool): + """Test encoding and decoding of a torch.Tensor.""" + encoder = MsgpackEncoder(allow_pickle=allow_pickle) + decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle) + + # Create a sample tensor + tensor = torch.rand(10, 10) + + # Encode the tensor + encoded = encoder.encode(tensor) + + # Decode the tensor + decoded = decoder.decode(encoded) + + # Verify the decoded tensor matches the original + assert torch.allclose( + tensor, decoded), "Decoded tensor does not match the original tensor." + + +@pytest.mark.parametrize("allow_pickle", [True, False]) +def test_numpy_array_serialization(allow_pickle: bool): + """Test encoding and decoding of a numpy array.""" + encoder = MsgpackEncoder(allow_pickle=allow_pickle) + decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle) + + # Create a sample numpy array + array = np.random.rand(10, 10) + + # Encode the numpy array + encoded = encoder.encode(array) + + # Decode the numpy array + decoded = decoder.decode(encoded) + + # Verify the decoded array matches the original + assert np.allclose( + array, + decoded), "Decoded numpy array does not match the original array." + + +class CustomClass: + + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, CustomClass) and self.value == other.value + + +def test_custom_class_serialization_allowed_with_pickle(): + """Test that serializing a custom class succeeds when allow_pickle=True.""" + encoder = MsgpackEncoder(allow_pickle=True) + decoder = MsgpackDecoder(CustomClass, allow_pickle=True) + + obj = CustomClass("test_value") + + # Encode the custom class + encoded = encoder.encode(obj) + + # Decode the custom class + decoded = decoder.decode(encoded) + + # Verify the decoded object matches the original + assert obj == decoded, "Decoded object does not match the original object." + + +def test_custom_class_serialization_disallowed_without_pickle(): + """Test that serializing a custom class fails when allow_pickle=False.""" + encoder = MsgpackEncoder(allow_pickle=False) + + obj = CustomClass("test_value") + + with pytest.raises(TypeError): + # Attempt to encode the custom class + encoder.encode(obj) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index a3ad8cb920962..e00ecde66af08 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -47,7 +47,9 @@ class MsgpackEncoder: via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self, size_threshold: Optional[int] = None): + def __init__(self, + size_threshold: Optional[int] = None, + allow_pickle: bool = True): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -56,6 +58,7 @@ class MsgpackEncoder: # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold + self.allow_pickle = allow_pickle def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -105,6 +108,9 @@ class MsgpackEncoder: for itemlist in mm._items_by_modality.values() for item in itemlist] + if not self.allow_pickle: + raise TypeError(f"Object of type {type(obj)} is not serializable") + if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. @@ -179,12 +185,13 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. """ - def __init__(self, t: Optional[Any] = None): + def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True): args = () if t is None else (t, ) self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook) self.aux_buffers: Sequence[bytestr] = () + self.allow_pickle = allow_pickle def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): @@ -265,10 +272,12 @@ class MsgpackDecoder: def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data - if code == CUSTOM_TYPE_PICKLE: - return pickle.loads(data) - if code == CUSTOM_TYPE_CLOUDPICKLE: - return cloudpickle.loads(data) + + if self.allow_pickle: + if code == CUSTOM_TYPE_PICKLE: + return pickle.loads(data) + if code == CUSTOM_TYPE_CLOUDPICKLE: + return cloudpickle.loads(data) raise NotImplementedError( f"Extension type code {code} is not supported")