mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 02:05:01 +08:00
[Core][MultiModalHasher] Don't convert memoryviews to bytes during hashing (#24925)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
73cfb3c5ee
commit
08369289af
@ -20,22 +20,22 @@ logger = init_logger(__name__)
|
|||||||
class MultiModalHasher:
|
class MultiModalHasher:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
|
def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]:
|
||||||
# Simple cases
|
# Simple cases
|
||||||
if isinstance(obj, str):
|
|
||||||
return obj.encode("utf-8")
|
|
||||||
if isinstance(obj, (bytes, memoryview)):
|
if isinstance(obj, (bytes, memoryview)):
|
||||||
return obj
|
return (obj, )
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return (obj.encode("utf-8"), )
|
||||||
if isinstance(obj, (int, float)):
|
if isinstance(obj, (int, float)):
|
||||||
return np.array(obj).tobytes()
|
return (np.array(obj).tobytes(), )
|
||||||
|
|
||||||
if isinstance(obj, Image.Image):
|
if isinstance(obj, Image.Image):
|
||||||
exif = obj.getexif()
|
exif = obj.getexif()
|
||||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID):
|
exif[Image.ExifTags.Base.ImageID], uuid.UUID):
|
||||||
# If the image has exif ImageID tag, use that
|
# If the image has exif ImageID tag, use that
|
||||||
return exif[Image.ExifTags.Base.ImageID].bytes
|
return (exif[Image.ExifTags.Base.ImageID].bytes, )
|
||||||
return cls.item_to_bytes(
|
return cls.iter_item_to_bytes(
|
||||||
"image", np.asarray(convert_image_mode(obj, "RGBA")))
|
"image", np.asarray(convert_image_mode(obj, "RGBA")))
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
tensor_obj: torch.Tensor = obj.cpu()
|
tensor_obj: torch.Tensor = obj.cpu()
|
||||||
@ -49,43 +49,34 @@ class MultiModalHasher:
|
|||||||
tensor_obj = tensor_obj.view(
|
tensor_obj = tensor_obj.view(
|
||||||
(tensor_obj.numel(), )).view(torch.uint8)
|
(tensor_obj.numel(), )).view(torch.uint8)
|
||||||
|
|
||||||
return cls.item_to_bytes(
|
return cls.iter_item_to_bytes(
|
||||||
"tensor", {
|
"tensor", {
|
||||||
"original_dtype": str(tensor_dtype),
|
"original_dtype": str(tensor_dtype),
|
||||||
"original_shape": tuple(tensor_shape),
|
"original_shape": tuple(tensor_shape),
|
||||||
"data": tensor_obj.numpy(),
|
"data": tensor_obj.numpy(),
|
||||||
})
|
})
|
||||||
|
return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
|
||||||
return cls.item_to_bytes("tensor", tensor_obj.numpy())
|
|
||||||
if isinstance(obj, np.ndarray):
|
if isinstance(obj, np.ndarray):
|
||||||
# If the array is non-contiguous, we need to copy it first
|
# If the array is non-contiguous, we need to copy it first
|
||||||
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
|
arr_data = obj.view(
|
||||||
return cls.item_to_bytes("ndarray", {
|
np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
|
||||||
|
return cls.iter_item_to_bytes("ndarray", {
|
||||||
"dtype": obj.dtype.str,
|
"dtype": obj.dtype.str,
|
||||||
"shape": obj.shape,
|
"shape": obj.shape,
|
||||||
"data": arr_data,
|
"data": arr_data,
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No serialization method found for %s. "
|
"No serialization method found for %s. "
|
||||||
"Falling back to pickle.", type(obj))
|
"Falling back to pickle.", type(obj))
|
||||||
|
|
||||||
return pickle.dumps(obj)
|
return (pickle.dumps(obj), )
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def item_to_bytes(
|
|
||||||
cls,
|
|
||||||
key: str,
|
|
||||||
obj: object,
|
|
||||||
) -> bytes:
|
|
||||||
return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def iter_item_to_bytes(
|
def iter_item_to_bytes(
|
||||||
cls,
|
cls,
|
||||||
key: str,
|
key: str,
|
||||||
obj: object,
|
obj: object,
|
||||||
) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]:
|
) -> Iterable[Union[bytes, memoryview]]:
|
||||||
# Recursive cases
|
# Recursive cases
|
||||||
if isinstance(obj, (list, tuple)):
|
if isinstance(obj, (list, tuple)):
|
||||||
for i, elem in enumerate(obj):
|
for i, elem in enumerate(obj):
|
||||||
@ -94,17 +85,15 @@ class MultiModalHasher:
|
|||||||
for k, v in obj.items():
|
for k, v in obj.items():
|
||||||
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
|
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
|
||||||
else:
|
else:
|
||||||
key_bytes = key.encode("utf-8")
|
yield key.encode("utf-8")
|
||||||
value_bytes = cls.serialize_item(obj)
|
yield from cls.serialize_item(obj)
|
||||||
yield key_bytes, value_bytes
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def hash_kwargs(cls, **kwargs: object) -> str:
|
def hash_kwargs(cls, **kwargs: object) -> str:
|
||||||
hasher = blake3()
|
hasher = blake3()
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v):
|
for bytes_ in cls.iter_item_to_bytes(k, v):
|
||||||
hasher.update(k_bytes)
|
hasher.update(bytes_)
|
||||||
hasher.update(v_bytes)
|
|
||||||
|
|
||||||
return hasher.hexdigest()
|
return hasher.hexdigest()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user