diff --git a/nodes.py b/nodes.py
index a7f9a21..599b306 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1281,7 +1281,7 @@ class BatchCropFromMask:
             "required": {
                 "original_images": ("IMAGE",),
                 "masks": ("MASK",),
-                "bbox_size": ("INT", {"default": 256, "min": 64, "max": 1024, "step": 8}),
+                "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
             },
         }
@@ -1302,14 +1302,31 @@ class BatchCropFromMask:
     FUNCTION = "crop"
     CATEGORY = "KJNodes/masking"

-    def crop(self, masks, original_images, bbox_size):
+    def crop(self, masks, original_images, crop_size_mult):
         bounding_boxes = []
         cropped_images = []

+        # First, calculate the maximum bounding box size across all masks
+        max_bbox_size = 0
+        for mask in masks:
+            _mask = tensor2pil(mask)[0]
+            non_zero_indices = np.nonzero(np.array(_mask))
+            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+            width = max_x - min_x
+            height = max_y - min_y
+            bbox_size = max(width, height)
+            max_bbox_size = max(max_bbox_size, bbox_size)
+
+        # Round max_bbox_size up to the nearest multiple of 32
+        max_bbox_size = math.ceil(max_bbox_size / 32) * 32
+
+        # Apply the crop size multiplier
+        max_bbox_size = int(max_bbox_size * crop_size_mult)
+
+        # Then, for each mask and corresponding image, crop around the mask center
         for mask, img in zip(masks, original_images):
             _mask = tensor2pil(mask)[0]
-
-            # Calculate bounding box coordinates
             non_zero_indices = np.nonzero(np.array(_mask))
             min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
             min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
@@ -1318,36 +1335,23 @@ class BatchCropFromMask:
             center_x = (max_x + min_x) // 2
             center_y = (max_y + min_y) // 2

-            # Create fixed-size bounding box around center
-            half_box_size = bbox_size // 2
-            min_x = center_x - half_box_size
-            max_x = center_x + half_box_size
-            min_y = center_y - half_box_size
-            max_y = center_y + half_box_size
-
-            # Check if the bounding box dimensions go outside the image dimensions
-            if min_x < 0:
-                max_x -= min_x
-                min_x = 0
-            if max_x > img.shape[1]:
-                min_x -= max_x - img.shape[1]
-                max_x = img.shape[1]
-            if min_y < 0:
-                max_y -= min_y
-                min_y = 0
-            if max_y > img.shape[0]:
-                min_y -= max_y - img.shape[0]
-                max_y = img.shape[0]
+            # Create bounding box using max_bbox_size, clamped to the image bounds
+            half_box_size = max_bbox_size // 2
+            min_x = max(0, center_x - half_box_size)
+            max_x = min(img.shape[1], center_x + half_box_size)
+            min_y = max(0, center_y - half_box_size)
+            max_y = min(img.shape[0], center_y + half_box_size)

             # Append bounding box coordinates
             bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
-
+
             # Crop the image from the bounding box
             cropped_img = img[min_y:max_y, min_x:max_x, :]
             cropped_images.append(cropped_img)
+
         cropped_out = torch.stack(cropped_images, dim=0)
-
-        return (original_images, cropped_out, bounding_boxes, bbox_size, bbox_size,)
+
+        return (original_images, cropped_out, bounding_boxes, max_bbox_size, max_bbox_size,)


 def bbox_to_region(bbox, target_size=None):
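
For reference, the new sizing pass can be exercised in isolation. This is a minimal sketch assuming plain 2D NumPy masks; `shared_crop_size` is a hypothetical helper name, not part of the patch, but the arithmetic mirrors the loop added to `crop()` above:

```python
import math
import numpy as np

def shared_crop_size(masks, crop_size_mult=1.0):
    # Hypothetical helper mirroring the sizing pass in crop():
    # take the largest mask extent across the batch, round it up
    # to a multiple of 32, then scale by crop_size_mult.
    max_bbox_size = 0
    for mask in masks:  # each mask: 2D array, nonzero where the subject is
        ys, xs = np.nonzero(mask)
        max_bbox_size = max(max_bbox_size, xs.max() - xs.min(), ys.max() - ys.min())
    max_bbox_size = math.ceil(max_bbox_size / 32) * 32
    return int(max_bbox_size * crop_size_mult)

a = np.zeros((256, 256)); a[50:150, 60:120] = 1  # extent 99 (rows 50..149)
b = np.zeros((256, 256)); b[10:40, 10:90] = 1    # extent 79 (cols 10..89)
print(shared_crop_size([a, b], crop_size_mult=1.5))  # ceil(99/32)*32 = 128 -> 192
```

One consequence of the shared box size: the final `torch.stack` requires every crop to have identical dimensions, which holds as long as the border clamping doesn't shrink any individual box.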