diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 0a6869d..865d777 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2474,21 +2474,52 @@ class ImageCropByMaskAndResize: return (int(x0), int(y0), int(w), int(h)) def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512): + mask = mask.round() image_list = [] mask_list = [] bbox_list = [] + + # First, collect all bounding boxes + bbox_params = [] + aspect_ratios = [] for i in range(image.shape[0]): x0, y0, w, h = self.crop_by_mask(mask[i], padding, min_crop_resolution, max_crop_resolution) - cropped_image = image[i][y0:y0+h, x0:x0+w, :] - cropped_mask = mask[i][y0:y0+h, x0:x0+w] - - aspect_ratio = w / h - if aspect_ratio > 1: - target_width = base_resolution - target_height = int(base_resolution / aspect_ratio) - else: - target_height = base_resolution - target_width = int(base_resolution * aspect_ratio) + bbox_params.append((x0, y0, w, h)) + aspect_ratios.append(w / h) + #print(bbox_params) + + # Find maximum width and height + max_w = max([w for x0, y0, w, h in bbox_params]) + max_h = max([h for x0, y0, w, h in bbox_params]) + max_aspect_ratio = max(aspect_ratios) + + # Ensure dimensions are divisible by 8 + max_w = (max_w + 7) // 8 * 8 + max_h = (max_h + 7) // 8 * 8 + # Calculate common target dimensions + if max_aspect_ratio > 1: + target_width = base_resolution + target_height = int(base_resolution / max_aspect_ratio) + else: + target_height = base_resolution + target_width = int(base_resolution * max_aspect_ratio) + + for i in range(image.shape[0]): + x0, y0, w, h = bbox_params[i] + + # Adjust cropping to use maximum width and height + x_center = x0 + w / 2 + y_center = y0 + h / 2 + + x0_new = int(max(0, x_center - max_w / 2)) + y0_new = int(max(0, y_center - max_h / 2)) + x1_new = int(min(x0_new + max_w, image.shape[2])) + y1_new = int(min(y0_new + max_h, image.shape[1])) + x0_new = x1_new - max_w + y0_new = y1_new - max_h + + cropped_image = image[i][y0_new:y1_new, x0_new:x1_new, :] + cropped_mask = mask[i][y0_new:y1_new, x0_new:x1_new] # Ensure dimensions are divisible by 8 target_width = (target_width + 7) // 8 * 8 @@ -2504,7 +2535,8 @@ class ImageCropByMaskAndResize: image_list.append(cropped_image) mask_list.append(cropped_mask) - bbox_list.append((x0, y0, x0 + w, y0 + h)) + bbox_list.append((x0_new, y0_new, x1_new, y1_new)) + return (torch.stack(image_list), torch.stack(mask_list), bbox_list)