diff --git a/__init__.py b/__init__.py index c152380..285213b 100644 --- a/__init__.py +++ b/__init__.py @@ -56,6 +56,8 @@ NODE_CONFIG = { "ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"}, "ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"}, "ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"}, + "ImageCropByMaskAndResize": {"class": ImageCropByMaskAndResize, "name": "Image Crop By Mask And Resize"}, + "ImageUncropByMask": {"class": ImageUncropByMask, "name": "Image Uncrop By Mask"}, "ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"}, "ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"}, "ImageGridComposite3x3": {"class": ImageGridComposite3x3, "name": "Image Grid Composite 3x3"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index b6b9784..3141720 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2325,4 +2325,160 @@ class FastPreview: return { "ui": {"bg_image": [img_base64]}, "result": () - } \ No newline at end of file + } + +class ImageCropByMaskAndResize: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE", ), + "mask": ("MASK", ), + "base_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), + "padding": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }), + }, + } + + RETURN_TYPES = ("IMAGE", "MASK", "BBOX", ) + RETURN_NAMES = ("images", "masks", "bbox",) + FUNCTION = "crop" + CATEGORY = "KJNodes/image" + + def crop_by_mask(self, mask, padding=0): + iy, ix = (mask == 1).nonzero(as_tuple=True) + h0, w0 = mask.shape + + if iy.numel() == 0: + x_c = w0 / 2.0 + y_c = h0 / 2.0 + width = 0 + height = 0 + else: + x_min = ix.min().item() + x_max = ix.max().item() + y_min = iy.min().item() + y_max = iy.max().item() + + width = x_max - x_min + height = y_max - y_min + + if width > w0 or height > h0: + raise Exception("Masked area out of bounds") + + x_c = (x_min + x_max) / 2.0 + y_c = (y_min + y_max) / 2.0 + + width_half = width / 2.0 + height_half = height / 2.0 + + if w0 <= width: + x0 = 0 + w = w0 + else: + x0 = max(0, x_c - width_half - padding) + w = width + 2 * padding + if x0 + w > w0: + x0 = w0 - w + + if h0 <= height: + y0 = 0 + h = h0 + else: + y0 = max(0, y_c - height_half - padding) + h = height + 2 * padding + if y0 + h > h0: + y0 = h0 - h + + return (int(x0), int(y0), int(w), int(h)) + + def crop(self, image, mask, base_resolution, padding=0): + image_list = [] + mask_list = [] + bbox_list = [] + for i in range(image.shape[0]): + x0, y0, w, h = self.crop_by_mask(mask[i], padding) + cropped_image = image[i][y0:y0+h, x0:x0+w, :] + cropped_mask = mask[i][y0:y0+h, x0:x0+w] + + cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W) + + aspect_ratio = w / h + if aspect_ratio > 1: + target_width = base_resolution + target_height = int(base_resolution / aspect_ratio) + else: + target_height = base_resolution + target_width = int(base_resolution * aspect_ratio) + + cropped_image = F.interpolate(cropped_image, size=(target_height, target_width), mode='bilinear', align_corners=False) + cropped_image = cropped_image.movedim(1, -1).squeeze(0) + + cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0) + cropped_mask = F.interpolate(cropped_mask, size=(target_height, target_width), mode='nearest') + cropped_mask = cropped_mask.squeeze(0).squeeze(0) + + image_list.append(cropped_image) + mask_list.append(cropped_mask) + bbox_list.append((x0, y0, x0 + w, y0 + h)) + + return (torch.stack(image_list), torch.stack(mask_list), bbox_list) + +class ImageUncropByMask: + + @classmethod + def INPUT_TYPES(s): + return {"required": + { + "destination": ("IMAGE",), + "source": ("IMAGE",), + "mask": ("MASK",), + "bbox": ("BBOX",), + }, + } + + CATEGORY = "KJNodes/image" + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image",) + FUNCTION = "uncrop" + + def uncrop(self, destination, source, mask, bbox=None): + + output_list = [] + + B, H, W, C = destination.shape + + for i in range(source.shape[0]): + x0, y0, x1, y1 = bbox[i] + bbox_height = y1 - y0 + bbox_width = x1 - x0 + + # Resize source image to match the bounding box dimensions + resized_source = F.interpolate(source[i].unsqueeze(0).movedim(-1, 1), size=(bbox_height, bbox_width), mode='bilinear', align_corners=False) + resized_source = resized_source.movedim(1, -1).squeeze(0) + + # Resize mask to match the bounding box dimensions + resized_mask = F.interpolate(mask[i].unsqueeze(0).unsqueeze(0), size=(bbox_height, bbox_width), mode='nearest') + resized_mask = resized_mask.squeeze(0).squeeze(0) + + # Calculate padding values + pad_left = x0 + pad_right = W - x1 + pad_top = y0 + pad_bottom = H - y1 + + # Pad the resized source image and mask to fit the destination dimensions + padded_source = F.pad(resized_source, pad=(0, 0, pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) + padded_mask = F.pad(resized_mask, pad=(pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) + + # Ensure the padded mask has the correct shape + padded_mask = padded_mask.unsqueeze(2).expand(-1, -1, destination[i].shape[2]) + # Ensure the padded source has the correct shape + padded_source = padded_source.unsqueeze(2).expand(-1, -1, -1, destination[i].shape[2]).squeeze(2) + + # Combine the destination and padded source images using the mask + result = destination[i] * (1.0 - padded_mask) + padded_source * padded_mask + + output_list.append(result) + + + return (torch.stack(output_list),) \ No newline at end of file