diff --git a/__init__.py b/__init__.py
index 8880b88..5d3f45a 100644
--- a/__init__.py
+++ b/__init__.py
@@ -60,6 +60,7 @@ NODE_CONFIG = {
     "ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"},
     "ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"},
     "ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"},
+    "ImageCropByMask": {"class": ImageCropByMask, "name": "Image Crop By Mask"},
     "ImageCropByMaskAndResize": {"class": ImageCropByMaskAndResize, "name": "Image Crop By Mask And Resize"},
     "ImageCropByMaskBatch": {"class": ImageCropByMaskBatch, "name": "Image Crop By Mask Batch"},
     "ImageUncropByMask": {"class": ImageUncropByMask, "name": "Image Uncrop By Mask"},
@@ -72,6 +73,7 @@ NODE_CONFIG = {
     "ImagePass": {"class": ImagePass},
     "ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"},
     "ImagePadForOutpaintTargetSize": {"class": ImagePadForOutpaintTargetSize, "name": "Image Pad For Outpaint Target Size"},
+    "ImagePrepForICLora": {"class": ImagePrepForICLora, "name": "Image Prep For ICLora"},
     "ImageResizeKJ": {"class": ImageResizeKJ, "name": "Resize Image"},
     "ImageUpscaleWithModelBatched": {"class": ImageUpscaleWithModelBatched, "name": "Image Upscale With Model Batched"},
     "InsertImagesToBatchIndexed": {"class": InsertImagesToBatchIndexed, "name": "Insert Images To Batch Indexed"},
diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py
index d02d6f8..f93c81c 100644
--- a/nodes/image_nodes.py
+++ b/nodes/image_nodes.py
@@ -1092,7 +1092,67 @@ class ImagePadForOutpaintTargetSize:
 
         # Now call the original expand_image with the calculated padding
         return ImagePadForOutpaintMasked.expand_image(self, image_scaled, pad_left, pad_top, pad_right, pad_bottom, feathering, mask_scaled)
-    
+
+class ImagePrepForICLora:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
+                "output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
+                "border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
+            },
+            "optional": {
+                "mask": ("MASK",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE", "MASK")
+    FUNCTION = "expand_image"
+
+    CATEGORY = "KJNodes/image"
+
+    def expand_image(self, image, output_width, output_height, border_width, mask=None):
+        if mask is not None:
+            if torch.allclose(mask, torch.zeros_like(mask)):
+                print("Warning: The incoming mask is fully black. Handling it as None.")
+                mask = None
+        B, H, W, C = image.size()
+
+        # Apply the optional mask to the reference image
+        if mask is not None:
+            resized_mask = torch.nn.functional.interpolate(
+                mask.unsqueeze(1),
+                size=(image.shape[1], image.shape[2]),
+                mode='nearest'
+            ).squeeze(1)
+            image = image * resized_mask.unsqueeze(-1)
+
+        # Calculate the width that keeps the source aspect ratio at output_height
+        new_width = int((W / H) * output_height)
+
+        # Resize the reference image to the output height
+        resized_image = common_upscale(image.movedim(-1,1), new_width, output_height, "lanczos", "disabled").movedim(1,-1)
+
+        # Empty canvas of the requested output size; the final image is the
+        # reference, optional border and canvas side by side, so its width is
+        # new_width + border_width + output_width
+        empty_image = torch.zeros((B, output_height, output_width, C), device=image.device)
+
+        if border_width > 0:
+            border = torch.zeros((B, output_height, border_width, C), device=image.device)
+            padded_image = torch.cat((resized_image, border, empty_image), dim=2)
+        else:
+            padded_image = torch.cat((resized_image, empty_image), dim=2)
+
+        # Mask is 0 over the reference (and border), 1 over the empty canvas
+        padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+        padded_mask[:, :, :new_width + border_width] = 0
+
+        return (padded_image, padded_mask)
+
+
 class ImageAndMaskPreview(SaveImage):
     def __init__(self):
         self.output_dir = folder_paths.get_temp_directory()
@@ -2789,6 +2849,51 @@ class ImageCropByMaskAndResize:
 
         return (torch.stack(image_list), torch.stack(mask_list), bbox_list)
 
+class ImageCropByMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE", ),
+                "mask": ("MASK", ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE", )
+    RETURN_NAMES = ("image", )
+    FUNCTION = "crop"
+    CATEGORY = "KJNodes/image"
+
+    def crop(self, image, mask):
+        B, H, W, C = image.shape
+        mask = mask.round()
+
+        # Find the mask bounding box for each image in the batch
+        crops = []
+
+        for b in range(B):
+            # Reuse the last mask if the mask batch is smaller than the image batch
+            m = mask[min(b, mask.shape[0] - 1)]
+
+            # Rows/columns that contain at least one masked pixel
+            rows = torch.any(m > 0, dim=1)
+            cols = torch.any(m > 0, dim=0)
+
+            # Bounding box boundaries (first and last non-zero row/column)
+            y_min, y_max = torch.where(rows)[0][[0, -1]]
+            x_min, x_max = torch.where(cols)[0][[0, -1]]
+
+            # Crop the image to the mask bounding box
+            crop = image[b:b+1, y_min:y_max+1, x_min:x_max+1, :]
+            crops.append(crop)
+
+        # Stack the crops back into a batch; when B > 1 every mask must
+        # produce the same crop size or this concatenation will fail
+        cropped_images = torch.cat(crops, dim=0)
+
+        return (cropped_images, )
+
+
 class ImageUncropByMask:
     @classmethod
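
A quick sanity check of the two new nodes, as a minimal sketch rather than part of the patch: it assumes a ComfyUI environment where comfy.utils.common_upscale is importable, and that the classes can be imported via the nodes.image_nodes path shown in the diff (the actual package layout may differ).

    import torch
    from nodes.image_nodes import ImagePrepForICLora, ImageCropByMask  # import path is an assumption

    # Batch of two 512x768 RGB images with a centered rectangular mask
    image = torch.rand(2, 512, 768, 3)
    mask = torch.zeros(2, 512, 768)
    mask[:, 128:384, 192:576] = 1.0

    # ImagePrepForICLora: reference resized to output_height, then border and
    # empty canvas appended on the right; mask is 1 over the canvas only
    padded_image, padded_mask = ImagePrepForICLora().expand_image(
        image, output_width=1024, output_height=1024, border_width=64)
    new_width = int((768 / 512) * 1024)  # aspect-preserving width: 1536
    assert padded_image.shape == (2, 1024, new_width + 64 + 1024, 3)
    assert torch.all(padded_mask[:, :, :new_width + 64] == 0)

    # ImageCropByMask: crop each image to its mask's bounding box
    cropped, = ImageCropByMask().crop(image, mask)
    assert cropped.shape == (2, 256, 384, 3)  # rows 128:384, cols 192:576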