Further fixes to ImageCropByMaskAndResize batch handling

This commit is contained in:
kijai 2024-11-06 14:21:57 +02:00
parent 3f903091b3
commit a982a31956

View File

@ -2474,21 +2474,52 @@ class ImageCropByMaskAndResize:
return (int(x0), int(y0), int(w), int(h))
def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512):
mask = mask.round()
image_list = []
mask_list = []
bbox_list = []
# First, collect all bounding boxes
bbox_params = []
aspect_ratios = []
for i in range(image.shape[0]):
x0, y0, w, h = self.crop_by_mask(mask[i], padding, min_crop_resolution, max_crop_resolution)
cropped_image = image[i][y0:y0+h, x0:x0+w, :]
cropped_mask = mask[i][y0:y0+h, x0:x0+w]
aspect_ratio = w / h
if aspect_ratio > 1:
target_width = base_resolution
target_height = int(base_resolution / aspect_ratio)
else:
target_height = base_resolution
target_width = int(base_resolution * aspect_ratio)
bbox_params.append((x0, y0, w, h))
aspect_ratios.append(w / h)
#print(bbox_params)
# Find maximum width and height
max_w = max([w for x0, y0, w, h in bbox_params])
max_h = max([h for x0, y0, w, h in bbox_params])
max_aspect_ratio = max(aspect_ratios)
# Ensure dimensions are divisible by 8
max_w = (max_w + 7) // 8 * 8
max_h = (max_h + 7) // 8 * 8
# Calculate common target dimensions
if max_aspect_ratio > 1:
target_width = base_resolution
target_height = int(base_resolution / max_aspect_ratio)
else:
target_height = base_resolution
target_width = int(base_resolution * max_aspect_ratio)
for i in range(image.shape[0]):
x0, y0, w, h = bbox_params[i]
# Adjust cropping to use maximum width and height
x_center = x0 + w / 2
y_center = y0 + h / 2
x0_new = int(max(0, x_center - max_w / 2))
y0_new = int(max(0, y_center - max_h / 2))
x1_new = int(min(x0_new + max_w, image.shape[2]))
y1_new = int(min(y0_new + max_h, image.shape[1]))
x0_new = x1_new - max_w
y0_new = y1_new - max_h
cropped_image = image[i][y0_new:y1_new, x0_new:x1_new, :]
cropped_mask = mask[i][y0_new:y1_new, x0_new:x1_new]
# Ensure dimensions are divisible by 8
target_width = (target_width + 7) // 8 * 8
@ -2504,7 +2535,8 @@ class ImageCropByMaskAndResize:
image_list.append(cropped_image)
mask_list.append(cropped_mask)
bbox_list.append((x0, y0, x0 + w, y0 + h))
bbox_list.append((x0_new, y0_new, x1_new, y1_new))
return (torch.stack(image_list), torch.stack(mask_list), bbox_list)