[Core][MultiModalHasher] Don't convert memoryviews to bytes during hashing (#24925)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-09-16 16:32:47 +01:00 committed by GitHub
parent 73cfb3c5ee
commit 08369289af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -20,22 +20,22 @@ logger = init_logger(__name__)
class MultiModalHasher: class MultiModalHasher:
@classmethod @classmethod
def serialize_item(cls, obj: object) -> Union[bytes, memoryview]: def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]:
# Simple cases # Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, (bytes, memoryview)): if isinstance(obj, (bytes, memoryview)):
return obj return (obj, )
if isinstance(obj, str):
return (obj.encode("utf-8"), )
if isinstance(obj, (int, float)): if isinstance(obj, (int, float)):
return np.array(obj).tobytes() return (np.array(obj).tobytes(), )
if isinstance(obj, Image.Image): if isinstance(obj, Image.Image):
exif = obj.getexif() exif = obj.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance( if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID): exif[Image.ExifTags.Base.ImageID], uuid.UUID):
# If the image has exif ImageID tag, use that # If the image has exif ImageID tag, use that
return exif[Image.ExifTags.Base.ImageID].bytes return (exif[Image.ExifTags.Base.ImageID].bytes, )
return cls.item_to_bytes( return cls.iter_item_to_bytes(
"image", np.asarray(convert_image_mode(obj, "RGBA"))) "image", np.asarray(convert_image_mode(obj, "RGBA")))
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu() tensor_obj: torch.Tensor = obj.cpu()
@@ -49,43 +49,34 @@ class MultiModalHasher:
tensor_obj = tensor_obj.view( tensor_obj = tensor_obj.view(
(tensor_obj.numel(), )).view(torch.uint8) (tensor_obj.numel(), )).view(torch.uint8)
return cls.item_to_bytes( return cls.iter_item_to_bytes(
"tensor", { "tensor", {
"original_dtype": str(tensor_dtype), "original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_shape), "original_shape": tuple(tensor_shape),
"data": tensor_obj.numpy(), "data": tensor_obj.numpy(),
}) })
return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
return cls.item_to_bytes("tensor", tensor_obj.numpy())
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
# If the array is non-contiguous, we need to copy it first # If the array is non-contiguous, we need to copy it first
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() arr_data = obj.view(
return cls.item_to_bytes("ndarray", { np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
return cls.iter_item_to_bytes("ndarray", {
"dtype": obj.dtype.str, "dtype": obj.dtype.str,
"shape": obj.shape, "shape": obj.shape,
"data": arr_data, "data": arr_data,
}) })
logger.warning( logger.warning(
"No serialization method found for %s. " "No serialization method found for %s. "
"Falling back to pickle.", type(obj)) "Falling back to pickle.", type(obj))
return pickle.dumps(obj) return (pickle.dumps(obj), )
@classmethod
def item_to_bytes(
cls,
key: str,
obj: object,
) -> bytes:
return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj))
@classmethod @classmethod
def iter_item_to_bytes( def iter_item_to_bytes(
cls, cls,
key: str, key: str,
obj: object, obj: object,
) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]: ) -> Iterable[Union[bytes, memoryview]]:
# Recursive cases # Recursive cases
if isinstance(obj, (list, tuple)): if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj): for i, elem in enumerate(obj):
@@ -94,17 +85,15 @@ class MultiModalHasher:
for k, v in obj.items(): for k, v in obj.items():
yield from cls.iter_item_to_bytes(f"{key}.{k}", v) yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
else: else:
key_bytes = key.encode("utf-8") yield key.encode("utf-8")
value_bytes = cls.serialize_item(obj) yield from cls.serialize_item(obj)
yield key_bytes, value_bytes
@classmethod @classmethod
def hash_kwargs(cls, **kwargs: object) -> str: def hash_kwargs(cls, **kwargs: object) -> str:
hasher = blake3() hasher = blake3()
for k, v in kwargs.items(): for k, v in kwargs.items():
for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v): for bytes_ in cls.iter_item_to_bytes(k, v):
hasher.update(k_bytes) hasher.update(bytes_)
hasher.update(v_bytes)
return hasher.hexdigest() return hasher.hexdigest()