mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 01:45:54 +08:00
[V1] Allow turning off pickle fallback in vllm.v1.serial_utils (#17427)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
739e03b344
commit
947f2f5375
@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||||
@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
|||||||
assert torch.equal(obj1.large_non_contig_tensor,
|
assert torch.equal(obj1.large_non_contig_tensor,
|
||||||
obj2.large_non_contig_tensor)
|
obj2.large_non_contig_tensor)
|
||||||
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_dict_serialization(allow_pickle: bool):
    """Round-trip a plain dict through MsgpackEncoder / MsgpackDecoder.

    The test is parametrized over ``allow_pickle`` and makes the same
    equality assertion in both modes, i.e. the round trip is expected to
    succeed whether or not the pickle fallback is enabled.
    """
    packer = MsgpackEncoder(allow_pickle=allow_pickle)
    unpacker = MsgpackDecoder(allow_pickle=allow_pickle)

    original = {"key": "value", "number": 42}

    # Encode, then immediately feed the encoder output to the decoder.
    restored = unpacker.decode(packer.encode(original))

    assert original == restored, (
        "Decoded object does not match the original object.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_tensor_serialization(allow_pickle: bool):
    """Round-trip a random 10x10 ``torch.Tensor`` through the codec.

    The decoder is constructed with ``torch.Tensor`` as its target type.
    The test is parametrized over ``allow_pickle`` and asserts the same
    outcome in both modes.
    """
    packer = MsgpackEncoder(allow_pickle=allow_pickle)
    unpacker = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)

    original = torch.rand(10, 10)

    # Encode, then immediately feed the encoder output to the decoder.
    restored = unpacker.decode(packer.encode(original))

    # Tolerance-based comparison of the restored values.
    assert torch.allclose(original, restored), (
        "Decoded tensor does not match the original tensor.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_numpy_array_serialization(allow_pickle: bool):
    """Round-trip a random 10x10 ``numpy.ndarray`` through the codec.

    The decoder is constructed with ``np.ndarray`` as its target type.
    The test is parametrized over ``allow_pickle`` and asserts the same
    outcome in both modes.
    """
    packer = MsgpackEncoder(allow_pickle=allow_pickle)
    unpacker = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)

    original = np.random.rand(10, 10)

    # Encode, then immediately feed the encoder output to the decoder.
    restored = unpacker.decode(packer.encode(original))

    # Tolerance-based comparison of the restored values.
    assert np.allclose(original, restored), (
        "Decoded numpy array does not match the original array.")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomClass:
    """Minimal user-defined type used as a serialization test fixture.

    Instances carry a single ``value`` attribute and compare equal when
    both operands are ``CustomClass`` instances with equal ``value``.

    Because ``__eq__`` is defined without ``__hash__``, Python makes
    instances unhashable; that is fine for a fixture that is only
    compared, never used as a dict key or set member.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of
        # answering False outright, so Python's reflected-comparison
        # protocol still runs. ``obj == <unrelated>`` still evaluates to
        # False via the default identity fallback.
        if not isinstance(other, CustomClass):
            return NotImplemented
        return self.value == other.value

    def __repr__(self) -> str:
        # Gives readable output when an equality assertion fails.
        return f"{type(self).__name__}(value={self.value!r})"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_class_serialization_allowed_with_pickle():
    """Round-trip a ``CustomClass`` instance with ``allow_pickle=True``.

    Both encoder and decoder are built with the pickle fallback enabled,
    and the decoder is given ``CustomClass`` as its target type. The
    restored object must compare equal to the one that was encoded.
    """
    packer = MsgpackEncoder(allow_pickle=True)
    unpacker = MsgpackDecoder(CustomClass, allow_pickle=True)

    original = CustomClass("test_value")

    # Encode, then immediately feed the encoder output to the decoder.
    restored = unpacker.decode(packer.encode(original))

    assert original == restored, (
        "Decoded object does not match the original object.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_class_serialization_disallowed_without_pickle():
    """Encoding a ``CustomClass`` must raise when ``allow_pickle=False``.

    Only the encoder side is exercised: the test expects ``encode`` to
    raise ``TypeError`` for the custom instance.
    """
    strict_encoder = MsgpackEncoder(allow_pickle=False)
    payload = CustomClass("test_value")

    # The encode call itself is expected to fail.
    with pytest.raises(TypeError):
        strict_encoder.encode(payload)
|
||||||
|
|||||||
@ -47,7 +47,9 @@ class MsgpackEncoder:
|
|||||||
via dedicated messages. Note that this is a per-tensor limit.
|
via dedicated messages. Note that this is a per-tensor limit.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, size_threshold: Optional[int] = None):
|
def __init__(self,
|
||||||
|
size_threshold: Optional[int] = None,
|
||||||
|
allow_pickle: bool = True):
|
||||||
if size_threshold is None:
|
if size_threshold is None:
|
||||||
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
||||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||||
@ -56,6 +58,7 @@ class MsgpackEncoder:
|
|||||||
# pass custom data to the hook otherwise.
|
# pass custom data to the hook otherwise.
|
||||||
self.aux_buffers: Optional[list[bytestr]] = None
|
self.aux_buffers: Optional[list[bytestr]] = None
|
||||||
self.size_threshold = size_threshold
|
self.size_threshold = size_threshold
|
||||||
|
self.allow_pickle = allow_pickle
|
||||||
|
|
||||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||||
try:
|
try:
|
||||||
@ -105,6 +108,9 @@ class MsgpackEncoder:
|
|||||||
for itemlist in mm._items_by_modality.values()
|
for itemlist in mm._items_by_modality.values()
|
||||||
for item in itemlist]
|
for item in itemlist]
|
||||||
|
|
||||||
|
if not self.allow_pickle:
|
||||||
|
raise TypeError(f"Object of type {type(obj)} is not serializable")
|
||||||
|
|
||||||
if isinstance(obj, FunctionType):
|
if isinstance(obj, FunctionType):
|
||||||
# `pickle` is generally faster than cloudpickle, but can have
|
# `pickle` is generally faster than cloudpickle, but can have
|
||||||
# problems serializing methods.
|
# problems serializing methods.
|
||||||
@ -179,12 +185,13 @@ class MsgpackDecoder:
|
|||||||
not thread-safe when encoding tensors / numpy arrays.
|
not thread-safe when encoding tensors / numpy arrays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, t: Optional[Any] = None):
|
def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True):
|
||||||
args = () if t is None else (t, )
|
args = () if t is None else (t, )
|
||||||
self.decoder = msgpack.Decoder(*args,
|
self.decoder = msgpack.Decoder(*args,
|
||||||
ext_hook=self.ext_hook,
|
ext_hook=self.ext_hook,
|
||||||
dec_hook=self.dec_hook)
|
dec_hook=self.dec_hook)
|
||||||
self.aux_buffers: Sequence[bytestr] = ()
|
self.aux_buffers: Sequence[bytestr] = ()
|
||||||
|
self.allow_pickle = allow_pickle
|
||||||
|
|
||||||
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
|
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
|
||||||
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
|
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
|
||||||
@ -265,6 +272,8 @@ class MsgpackDecoder:
|
|||||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||||
if code == CUSTOM_TYPE_RAW_VIEW:
|
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
if self.allow_pickle:
|
||||||
if code == CUSTOM_TYPE_PICKLE:
|
if code == CUSTOM_TYPE_PICKLE:
|
||||||
return pickle.loads(data)
|
return pickle.loads(data)
|
||||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user