Serialize tensors using int8 views (#16866)

Signed-off-by: Staszek Pasko <staszek@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 682e0b6d2f
commit 87aaadef73
--- a/tests/v1/test_serial_utils.py
+++ b/tests/v1/test_serial_utils.py
@@ -47,6 +47,10 @@ def test_encode_decode():
             torch.rand((1, 10), dtype=torch.float32),
             torch.rand((3, 5, 4000), dtype=torch.float64),
             torch.tensor(1984),  # test scalar too
+            # Make sure to test bf16 which numpy doesn't support.
+            torch.rand((3, 5, 1000), dtype=torch.bfloat16),
+            torch.tensor([float("-inf"), float("inf")] * 1024,
+                         dtype=torch.bfloat16),
         ],
         numpy_array=np.arange(512),
         unrecognized=UnrecognizedType(33),
@@ -64,7 +68,7 @@ def test_encode_decode():
     # There should be the main buffer + 4 large tensor buffers
     # + 1 large numpy array. "large" is <= 512 bytes.
     # The two small tensors are encoded inline.
-    assert len(encoded) == 6
+    assert len(encoded) == 8

     decoded: MyType = decoder.decode(encoded)

@@ -76,7 +80,7 @@ def test_encode_decode():

     encoded2 = encoder.encode_into(obj, preallocated)

-    assert len(encoded2) == 6
+    assert len(encoded2) == 8
     assert encoded2[0] is preallocated

     decoded2: MyType = decoder.decode(encoded2)
@@ -114,15 +118,15 @@ def test_multimodal_kwargs():

     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

-    # expected total encoding length, should be 44536, +-20 for minor changes
-    assert total_len >= 44516 and total_len <= 44556
+    # expected total encoding length, should be 44559, +-20 for minor changes
+    assert total_len >= 44539 and total_len <= 44579
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
     assert all(nested_equal(d[k], decoded[k]) for k in d)


 def test_multimodal_items_by_modality():
-    e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
-                                                        dtype=torch.int16),
+    e1 = MultiModalFieldElem("audio", "a0",
+                             torch.zeros(1000, dtype=torch.bfloat16),
                              MultiModalBatchedField())
     e2 = MultiModalFieldElem(
         "video",
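The new test cases add bfloat16 tensors because numpy has no bfloat16 dtype, so the previous path through `obj.numpy()` could not serialize them. Below is a minimal standalone sketch of the byte-view round trip these tests now exercise; it is illustrative only, not part of the commit, and every name in it is local to the sketch.

import numpy as np
import torch

original = torch.rand((3, 5, 1000), dtype=torch.bfloat16)

# Encode side: flatten, reinterpret as raw uint8 bytes, expose via numpy.
contiguous = original.contiguous()
byte_view = contiguous.view((contiguous.numel(), )).view(torch.uint8).numpy()
payload = bytes(byte_view)  # stand-in for the msgpack-encoded payload

# Decode side: wrap the bytes, then view back to the original dtype and shape.
wrapper = np.ndarray(buffer=bytearray(payload), dtype=np.uint8,
                     shape=(len(payload), ))
restored = torch.from_numpy(wrapper).view(torch.bfloat16).view(original.shape)

assert torch.equal(original, restored)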
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -80,7 +80,7 @@ class MsgpackEncoder:

     def enc_hook(self, obj: Any) -> Any:
         if isinstance(obj, torch.Tensor):
-            return self._encode_ndarray(obj.numpy())
+            return self._encode_tensor(obj)

         # Fall back to pickle for object or void kind ndarrays.
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
@@ -133,9 +133,27 @@ class MsgpackEncoder:
         # backing buffers that we've stashed in `aux_buffers`.
         return obj.dtype.str, obj.shape, data

+    def _encode_tensor(
+            self, obj: torch.Tensor
+    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
+        assert self.aux_buffers is not None
+        # this creates a copy of the tensor if it's not already contiguous
+        obj = obj.contiguous()
+        # view the tensor as a 1D array of bytes
+        arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
+        if obj.nbytes < self.size_threshold:
+            # Smaller tensors are encoded inline, just like ndarrays.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+        else:
+            # Otherwise encode index of backing buffer to avoid copy.
+            data = len(self.aux_buffers)
+            self.aux_buffers.append(arr.data)
+        dtype = str(obj.dtype)[6:]  # remove 'torch.' prefix
+        return dtype, obj.shape, data
+
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
-            return self._encode_ndarray(nt.numpy())
+            return self._encode_tensor(nt)
         if isinstance(nt, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
             # values are sometimes floats.
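For clarity, here is a minimal standalone sketch of the encoding rule that `_encode_tensor` implements: a tensor below the size threshold travels inline as raw bytes, while a larger one is replaced by the index of a separate zero-copy buffer. The names `encode_tensor` and `SIZE_THRESHOLD`, and the use of plain `bytes` instead of a msgpack Ext for the inline case, are assumptions of this sketch, not vllm's API.

from typing import Union

import torch

SIZE_THRESHOLD = 256  # bytes; assumed stand-in for the encoder's size_threshold


def encode_tensor(
    obj: torch.Tensor,
    aux_buffers: list,
) -> tuple[str, tuple[int, ...], Union[int, bytes]]:
    obj = obj.contiguous()  # copies only if the tensor is non-contiguous
    # Reinterpret the flattened tensor as raw uint8 bytes.
    arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
    if obj.nbytes < SIZE_THRESHOLD:
        data: Union[int, bytes] = bytes(arr.data)  # inline payload (copied)
    else:
        data = len(aux_buffers)       # index into the out-of-band buffers
        aux_buffers.append(arr.data)  # memoryview, no copy of the bytes
    return str(obj.dtype)[6:], tuple(obj.shape), data


bufs: list = []
small = encode_tensor(torch.tensor(1984), bufs)                       # inline
large = encode_tensor(torch.rand(4096, dtype=torch.bfloat16), bufs)   # index 0
assert isinstance(small[2], bytes) and large[2] == 0 and len(bufs) == 1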
@@ -186,7 +204,7 @@ class MsgpackDecoder:
         if issubclass(t, np.ndarray):
             return self._decode_ndarray(obj)
         if issubclass(t, torch.Tensor):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
         if issubclass(t, MultiModalKwargs):
             if isinstance(obj, list):
                 return MultiModalKwargs.from_items(
@@ -199,11 +217,24 @@ class MsgpackDecoder:

     def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
-        # Copy from inline representation, otherwise Torch is unhappy since
-        # the returned memory is non-writeable.
+        # zero-copy decode. We assume the ndarray will not be kept around,
+        # as it now locks the whole received message buffer in memory.
+        buffer = self.aux_buffers[data] if isinstance(data, int) else data
+        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+
+    def _decode_tensor(self, arr: Any) -> torch.Tensor:
+        dtype, shape, data = arr
+        # Copy from inline representation, to decouple the memory storage
+        # of the message from the original buffer. And also make Torch
+        # not complain about a readonly memoryview.
         buffer = self.aux_buffers[data] if isinstance(data, int) \
             else bytearray(data)
-        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+        # Create numpy wrapper around the bytes
+        arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
+        torch_dtype = getattr(torch, dtype)
+        assert isinstance(torch_dtype, torch.dtype)
+        # Convert back to proper shape & type
+        return torch.from_numpy(arr).view(torch_dtype).view(shape)

     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
         decoded_items = []
@@ -228,7 +259,7 @@ class MsgpackDecoder:
         if not isinstance(obj, list):
             raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
         if obj and isinstance(obj[0], str):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
         return [self._decode_nested_tensors(x) for x in obj]

     def ext_hook(self, code: int, data: memoryview) -> Any:
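And a minimal standalone sketch of the decode side shown in `_decode_tensor`: the raw bytes are wrapped in a uint8 numpy array and viewed back to the original dtype and shape, so a tensor rebuilt from an out-of-band buffer aliases the received bytes rather than copying them (inline payloads are first copied into a writable bytearray, as in the diff). The names below are illustrative only.

import numpy as np
import torch


def decode_tensor(dtype: str, shape: tuple, buffer: bytearray) -> torch.Tensor:
    # Wrap the raw bytes without copying, then reinterpret dtype and shape.
    arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
    torch_dtype = getattr(torch, dtype)
    return torch.from_numpy(arr).view(torch_dtype).view(shape)


# A "received" buffer holding the raw bytes of a 2x3 float32 tensor.
src = torch.arange(6, dtype=torch.float32)
buf = bytearray(src.view(torch.uint8).numpy().tobytes())

t = decode_tensor("float32", (2, 3), buf)
assert torch.equal(t, src.view(2, 3))

# Zero-copy: writing through the tensor is visible in the backing buffer.
t[0, 0] = 7.0
expected = torch.tensor([7.0], dtype=torch.float32).view(torch.uint8)
assert bytes(buf[:4]) == expected.numpy().tobytes()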