This commit is contained in:
Netanel Haber 2025-12-11 18:04:01 +02:00 committed by Netanel Haber
parent 4e558858b8
commit 7b55a619e4
2 changed files with 8 additions and 1836 deletions

File diff suppressed because it is too large Load Diff

View File

@ -113,7 +113,7 @@ pixel_statistics = {
@dataclass @dataclass
class DynamicResolutionParams: class DynamicResolutionParams:
image: Image.Image media: Image.Image
num_tiles: int num_tiles: int
num_embeddings: int num_embeddings: int
patch_size: tuple[int, int] patch_size: tuple[int, int]
@ -165,7 +165,7 @@ class DynamicResolutionImageTilingStrategy:
self, params: DynamicResolutionParams, **kwargs self, params: DynamicResolutionParams, **kwargs
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
# resize the image # resize the image
resized_img = params.image.resize( resized_img = params.media.resize(
( (
params.patch_size[0] * self._patch_size, params.patch_size[0] * self._patch_size,
params.patch_size[1] * self._patch_size, params.patch_size[1] * self._patch_size,
@ -183,7 +183,7 @@ class DynamicResolutionImageTilingStrategy:
# Only add thumbnail if resized image area is less than threshold % of # Only add thumbnail if resized image area is less than threshold % of
# thumbnail area # thumbnail area
if area_ratio < self._thumbnail_area_threshold: if area_ratio < self._thumbnail_area_threshold:
thumbnail_img = params.image.resize( thumbnail_img = params.media.resize(
(self._thumbnail_size, self._thumbnail_size) (self._thumbnail_size, self._thumbnail_size)
) )
processed_images.append(thumbnail_img) processed_images.append(thumbnail_img)
@ -192,7 +192,7 @@ class DynamicResolutionImageTilingStrategy:
def process_media( def process_media(
self, self,
image: Image.Image, media: Image.Image,
num_tokens_available: int, num_tokens_available: int,
data_augment: bool = False, data_augment: bool = False,
tiling_augment_prob: float = 0.4, tiling_augment_prob: float = 0.4,
@ -207,10 +207,10 @@ class DynamicResolutionImageTilingStrategy:
DynamicResolutionParams for the media DynamicResolutionParams for the media
""" """
current_num_tokens_available = num_tokens_available current_num_tokens_available = num_tokens_available
assert isinstance(image, Image.Image), ( assert isinstance(media, Image.Image), (
"Dynamic resolution is only supported for image media" "Dynamic resolution is only supported for image media"
) )
orig_width, orig_height = image.width, image.height orig_width, orig_height = media.width, media.height
closest_patch_height = round(orig_height / self._patch_size + 0.5) closest_patch_height = round(orig_height / self._patch_size + 0.5)
closest_patch_width = round(orig_width / self._patch_size + 0.5) closest_patch_width = round(orig_width / self._patch_size + 0.5)
@ -336,7 +336,7 @@ class DynamicResolutionImageTilingStrategy:
target_patch_width, target_patch_height, current_num_tokens_available target_patch_width, target_patch_height, current_num_tokens_available
) )
assert isinstance(image, Image.Image), ( assert isinstance(media, Image.Image), (
"Dynamic resolution is only supported for image media" "Dynamic resolution is only supported for image media"
) )
@ -374,7 +374,7 @@ class DynamicResolutionImageTilingStrategy:
) )
return DynamicResolutionParams( return DynamicResolutionParams(
image=image, media=media,
num_tiles=num_tiles, num_tiles=num_tiles,
num_embeddings=num_embeddings, num_embeddings=num_embeddings,
patch_size=(target_patch_width, target_patch_height), patch_size=(target_patch_width, target_patch_height),