diff --git a/nodes.py b/nodes.py index 270060f..a68baa1 100644 --- a/nodes.py +++ b/nodes.py @@ -4,7 +4,7 @@ import torch.nn.functional as F import torchvision.utils as vutils import scipy.ndimage import numpy as np -from PIL import ImageColor, Image, ImageDraw, ImageFont +from PIL import ImageFilter, Image, ImageDraw, ImageFont from PIL.PngImagePlugin import PngInfo import json import re @@ -64,7 +64,7 @@ class CreateFluidMask: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "createfluidmask" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/generate" @classmethod def INPUT_TYPES(s): @@ -141,7 +141,7 @@ class CreateAudioMask: RETURN_TYPES = ("IMAGE",) FUNCTION = "createaudiomask" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/generate" @classmethod def INPUT_TYPES(s): @@ -194,7 +194,7 @@ class CreateGradientMask: RETURN_TYPES = ("MASK",) FUNCTION = "createmask" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/generate" @classmethod def INPUT_TYPES(s): @@ -229,7 +229,7 @@ class CreateFadeMask: RETURN_TYPES = ("MASK",) FUNCTION = "createfademask" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/generate" @classmethod def INPUT_TYPES(s): @@ -454,7 +454,7 @@ class CreateTextMask: RETURN_TYPES = ("IMAGE", "MASK",) FUNCTION = "createtextmask" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/generate" @classmethod def INPUT_TYPES(s): @@ -534,7 +534,7 @@ class GrowMaskWithBlur: }, } - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking" RETURN_TYPES = ("MASK", "MASK",) RETURN_NAMES = ("mask", "mask_inverted",) @@ -602,7 +602,7 @@ class ColorToMask: RETURN_TYPES = ("MASK",) FUNCTION = "clip" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking" @classmethod def INPUT_TYPES(s): @@ -659,7 +659,7 @@ class ConditioningMultiCombine: RETURN_TYPES = ("CONDITIONING", "INT") RETURN_NAMES = ("combined", "inputcount") FUNCTION = "combine" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/conditioning" def combine(self, inputcount, **kwargs): cond_combine_node = nodes.ConditioningCombine() @@ -697,7 +697,7 @@ class ConditioningSetMaskAndCombine: RETURN_TYPES = ("CONDITIONING","CONDITIONING",) RETURN_NAMES = ("combined_positive", "combined_negative",) FUNCTION = "append" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/conditioning" def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength): c = [] @@ -743,7 +743,7 @@ class ConditioningSetMaskAndCombine3: RETURN_TYPES = ("CONDITIONING","CONDITIONING",) RETURN_NAMES = ("combined_positive", "combined_negative",) FUNCTION = "append" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/conditioning" def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength): c = [] @@ -799,7 +799,7 @@ class ConditioningSetMaskAndCombine4: RETURN_TYPES = ("CONDITIONING","CONDITIONING",) RETURN_NAMES = ("combined_positive", "combined_negative",) FUNCTION = "append" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/conditioning" def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength): c = [] @@ -865,7 +865,7 @@ class ConditioningSetMaskAndCombine5: RETURN_TYPES = ("CONDITIONING","CONDITIONING",) RETURN_NAMES = ("combined_positive", "combined_negative",) FUNCTION = "append" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking/conditioning" def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength): c = [] @@ -1034,7 +1034,7 @@ class ColorMatch: }, } - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking" RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) @@ -1230,8 +1230,6 @@ class ImageBatchTestPattern: FUNCTION = "generatetestpattern" CATEGORY = "KJNodes" - - def generatetestpattern(self, batch_size, start_from, width, height): out = [] # Generate the sequential numbers for each image @@ -1270,7 +1268,176 @@ class ImageBatchTestPattern: out.append(image) return (torch.cat(out, dim=0),) - + +#based on nodes from mtb https://github.com/melMass/comfy_mtb + +from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor + +class BatchCropFromMask: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "original_images": ("IMAGE",), + "masks": ("MASK",), + "bbox_size": ("INT", {"default": 256, "min": 64, "max": 1024, "step": 8}), + }, + } + + RETURN_TYPES = ( + "IMAGE", + "IMAGE", + "BBOX", + ) + RETURN_NAMES = ( + "original_images", + "cropped_images", + "bboxes", + ) + FUNCTION = "crop" + CATEGORY = "KJNodes/masking" + + def crop(self, masks, original_images, bbox_size): + bounding_boxes = [] + cropped_images = [] + + for mask, img in zip(masks, original_images): + _mask = tensor2pil(mask)[0] + + # Calculate bounding box coordinates + non_zero_indices = np.nonzero(np.array(_mask)) + min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) + min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) + + # Calculate center of bounding box + center_x = (max_x + min_x) // 2 + center_y = (max_y + min_y) // 2 + + # Create fixed-size bounding box around center + half_box_size = bbox_size // 2 + min_x = center_x - half_box_size + max_x = center_x + half_box_size + min_y = center_y - half_box_size + max_y = center_y + half_box_size + + # Check if the bounding box dimensions go outside the image dimensions + if min_x < 0: + max_x -= min_x + min_x = 0 + if max_x > img.shape[1]: + min_x -= max_x - img.shape[1] + max_x = img.shape[1] + if min_y < 0: + max_y -= min_y + min_y = 0 + if max_y > img.shape[0]: + min_y -= max_y - img.shape[0] + max_y = img.shape[0] + + # Append bounding box coordinates + bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y)) + + # Crop the image from the bounding box + cropped_img = img[min_y:max_y, min_x:max_x, :] + cropped_images.append(cropped_img) + cropped_out = torch.stack(cropped_images, dim=0) + + return (original_images, cropped_out, bounding_boxes,) + + +def bbox_to_region(bbox, target_size=None): + bbox = bbox_check(bbox, target_size) + return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) + +def bbox_check(bbox, target_size=None): + if not target_size: + return bbox + + new_bbox = ( + bbox[0], + bbox[1], + min(target_size[0] - bbox[0], bbox[2]), + min(target_size[1] - bbox[1], bbox[3]), + ) + return new_bbox + +class BatchUncrop: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "original_images": ("IMAGE",), + "cropped_images": ("IMAGE",), + "bboxes": ("BBOX",), + "border_blending": ( + "FLOAT", + {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, + ), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "uncrop" + + CATEGORY = "KJNodes/masking" + + def uncrop(self, original_images, cropped_images, bboxes, border_blending): + def inset_border(image, border_width=20, border_color=(0)): + width, height = image.size + bordered_image = Image.new(image.mode, (width, height), border_color) + bordered_image.paste(image, (0, 0)) + draw = ImageDraw.Draw(bordered_image) + draw.rectangle( + (0, 0, width - 1, height - 1), outline=border_color, width=border_width + ) + return bordered_image + + if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes): + raise ValueError("The number of images, crop_images, and bboxes should be the same") + + input_images = tensor2pil(original_images) + crop_imgs = tensor2pil(cropped_images) + out_images = [] + for i in range(len(input_images)): + img = input_images[i] + crop = crop_imgs[i] + bbox = bboxes[i] + + # uncrop the image based on the bounding box + bb_x, bb_y, bb_width, bb_height = bbox + + paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size) + + crop_img = crop.convert("RGB") + + if border_blending > 1.0: + border_blending = 1.0 + elif border_blending < 0.0: + border_blending = 0.0 + + blend_ratio = (max(crop_img.size) / 2) * float(border_blending) + + blend = img.convert("RGBA") + mask = Image.new("L", img.size, 0) + + mask_block = Image.new("L", (bb_width, bb_height), 255) + mask_block = inset_border(mask_block, int(blend_ratio / 2), (0)) + + mask.paste(mask_block, paste_region) + blend.paste(crop_img, paste_region) + + mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4)) + mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4)) + + blend.putalpha(mask) + img = Image.alpha_composite(img.convert("RGBA"), blend) + out_images.append(img.convert("RGB")) + + return (pil2tensor(out_images),) + + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -1298,7 +1465,9 @@ NODE_CLASS_MAPPINGS = { "ImageGridComposite3x3": ImageGridComposite3x3, "ImageConcanate": ImageConcanate, "ImageBatchTestPattern": ImageBatchTestPattern, - "ReplaceImagesInBatch": ReplaceImagesInBatch + "ReplaceImagesInBatch": ReplaceImagesInBatch, + "BatchCropFromMask": BatchCropFromMask, + "BatchUncrop": BatchUncrop, } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -1326,5 +1495,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageGridComposite3x3": "ImageGridComposite3x3", "ImageConcanate": "ImageConcanate", "ImageBatchTestPattern": "ImageBatchTestPattern", - "ReplaceImagesInBatch": "ReplaceImagesInBatch" + "ReplaceImagesInBatch": "ReplaceImagesInBatch", + "BatchCropFromMask": "BatchCropFromMask", + "BatchUncrop": "BatchUncrop", } \ No newline at end of file diff --git a/utility.py b/utility.py new file mode 100644 index 0000000..f3b5c42 --- /dev/null +++ b/utility.py @@ -0,0 +1,39 @@ +import torch +import numpy as np +from PIL import Image +from typing import Union, List + +# Utility functions from mtb nodes: https://github.com/melMass/comfy_mtb +def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor: + if isinstance(image, list): + return torch.cat([pil2tensor(img) for img in image], dim=0) + + return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) + + +def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor: + if isinstance(img_np, list): + return torch.cat([np2tensor(img) for img in img_np], dim=0) + + return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2np(tensor: torch.Tensor): + if len(tensor.shape) == 3: # Single image + return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) + else: # Batch of images + return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor] + +def tensor2pil(image: torch.Tensor) -> List[Image.Image]: + batch_count = image.size(0) if len(image.shape) > 3 else 1 + if batch_count > 1: + out = [] + for i in range(batch_count): + out.extend(tensor2pil(image[i])) + return out + + return [ + Image.fromarray( + np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + ) + ] \ No newline at end of file