Add batch uncrop/crop

Modified nodes from mtb nodes
kijai 2023-11-08 23:23:19 +02:00
parent 98bff811f3
commit dbd2ea2406
2 changed files with 229 additions and 19 deletions

nodes.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
import torchvision.utils as vutils
import scipy.ndimage
import numpy as np
- from PIL import ImageColor, Image, ImageDraw, ImageFont
+ from PIL import ImageFilter, Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import json
import re
@@ -64,7 +64,7 @@ class CreateFluidMask:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "createfluidmask"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
@@ -141,7 +141,7 @@ class CreateAudioMask:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "createaudiomask"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
@@ -194,7 +194,7 @@ class CreateGradientMask:
RETURN_TYPES = ("MASK",)
FUNCTION = "createmask"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
@@ -229,7 +229,7 @@ class CreateFadeMask:
RETURN_TYPES = ("MASK",)
FUNCTION = "createfademask"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
@@ -454,7 +454,7 @@ class CreateTextMask:
RETURN_TYPES = ("IMAGE", "MASK",)
FUNCTION = "createtextmask"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
@@ -534,7 +534,7 @@ class GrowMaskWithBlur:
},
}
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking"
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
@@ -602,7 +602,7 @@ class ColorToMask:
RETURN_TYPES = ("MASK",)
FUNCTION = "clip"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking"
@classmethod
def INPUT_TYPES(s):
@@ -659,7 +659,7 @@ class ConditioningMultiCombine:
RETURN_TYPES = ("CONDITIONING", "INT")
RETURN_NAMES = ("combined", "inputcount")
FUNCTION = "combine"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/conditioning"
def combine(self, inputcount, **kwargs):
cond_combine_node = nodes.ConditioningCombine()
@@ -697,7 +697,7 @@ class ConditioningSetMaskAndCombine:
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
RETURN_NAMES = ("combined_positive", "combined_negative",)
FUNCTION = "append"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/conditioning"
def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
c = []
@@ -743,7 +743,7 @@ class ConditioningSetMaskAndCombine3:
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
RETURN_NAMES = ("combined_positive", "combined_negative",)
FUNCTION = "append"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/conditioning"
def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
c = []
@@ -799,7 +799,7 @@ class ConditioningSetMaskAndCombine4:
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
RETURN_NAMES = ("combined_positive", "combined_negative",)
FUNCTION = "append"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/conditioning"
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
c = []
@@ -865,7 +865,7 @@ class ConditioningSetMaskAndCombine5:
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
RETURN_NAMES = ("combined_positive", "combined_negative",)
FUNCTION = "append"
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking/conditioning"
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
c = []
@@ -1034,7 +1034,7 @@ class ColorMatch:
},
}
- CATEGORY = "KJNodes"
+ CATEGORY = "KJNodes/masking"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
@@ -1230,8 +1230,6 @@ class ImageBatchTestPattern:
FUNCTION = "generatetestpattern"
- CATEGORY = "KJNodes"
def generatetestpattern(self, batch_size, start_from, width, height):
out = []
# Generate the sequential numbers for each image
@@ -1270,7 +1268,176 @@ class ImageBatchTestPattern:
out.append(image)
return (torch.cat(out, dim=0),)
# based on nodes from mtb: https://github.com/melMass/comfy_mtb
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
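# BatchCropFromMask: finds each mask's bounding box, centers a fixed-size
# square (bbox_size) on it, shifts the box back inside the image if it
# overflows, and returns the crops plus the bboxes for later uncropping.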
class BatchCropFromMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"masks": ("MASK",),
"bbox_size": ("INT", {"default": 256, "min": 64, "max": 1024, "step": 8}),
},
}
RETURN_TYPES = (
"IMAGE",
"IMAGE",
"BBOX",
)
RETURN_NAMES = (
"original_images",
"cropped_images",
"bboxes",
)
FUNCTION = "crop"
CATEGORY = "KJNodes/masking"
def crop(self, masks, original_images, bbox_size):
bounding_boxes = []
cropped_images = []
for mask, img in zip(masks, original_images):
_mask = tensor2pil(mask)[0]
# Calculate bounding box coordinates
non_zero_indices = np.nonzero(np.array(_mask))
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
# Calculate center of bounding box
center_x = (max_x + min_x) // 2
center_y = (max_y + min_y) // 2
# Create fixed-size bounding box around center
half_box_size = bbox_size // 2
min_x = center_x - half_box_size
max_x = center_x + half_box_size
min_y = center_y - half_box_size
max_y = center_y + half_box_size
# Check if the bounding box dimensions go outside the image dimensions
if min_x < 0:
max_x -= min_x
min_x = 0
if max_x > img.shape[1]:
min_x -= max_x - img.shape[1]
max_x = img.shape[1]
if min_y < 0:
max_y -= min_y
min_y = 0
if max_y > img.shape[0]:
min_y -= max_y - img.shape[0]
max_y = img.shape[0]
# Append bounding box coordinates
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
cropped_images.append(cropped_img)
cropped_out = torch.stack(cropped_images, dim=0)
return (original_images, cropped_out, bounding_boxes,)
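# bbox helpers: a bbox is (x, y, width, height); bbox_to_region converts it
# to the PIL-style (left, top, right, bottom) box that paste() expects, after
# bbox_check clamps width/height so the box stays inside target_size.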
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
def bbox_check(bbox, target_size=None):
if not target_size:
return bbox
new_bbox = (
bbox[0],
bbox[1],
min(target_size[0] - bbox[0], bbox[2]),
min(target_size[1] - bbox[1], bbox[3]),
)
return new_bbox
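# BatchUncrop: pastes each cropped image back into its original at the stored
# bbox, blending the seam with a blurred alpha mask controlled by border_blending.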
class BatchUncrop:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"bboxes": ("BBOX",),
"border_blending": (
"FLOAT",
{"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01},
),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "uncrop"
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, bboxes, border_blending):
def inset_border(image, border_width=20, border_color=(0)):
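# Helper: draws a frame of border_color just inside the image edges; applied
# to the white mask block below, it carves the black margin that the blurs
# turn into a smooth alpha falloff.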
width, height = image.size
bordered_image = Image.new(image.mode, (width, height), border_color)
bordered_image.paste(image, (0, 0))
draw = ImageDraw.Draw(bordered_image)
draw.rectangle(
(0, 0, width - 1, height - 1), outline=border_color, width=border_width
)
return bordered_image
if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
raise ValueError("The number of original_images, cropped_images, and bboxes must match")
input_images = tensor2pil(original_images)
crop_imgs = tensor2pil(cropped_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
crop = crop_imgs[i]
bbox = bboxes[i]
# uncrop the image based on the bounding box
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
crop_img = crop.convert("RGB")
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
border_blending = 0.0
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
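# Feather width scales with crop size: at border_blending = 1.0 the blend
# spans up to half of the larger crop dimension.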
blend = img.convert("RGBA")
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (bb_width, bb_height), 255)
mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
mask.paste(mask_block, paste_region)
blend.paste(crop_img, paste_region)
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
blend.putalpha(mask)
img = Image.alpha_composite(img.convert("RGBA"), blend)
out_images.append(img.convert("RGB"))
return (pil2tensor(out_images),)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
@@ -1298,7 +1465,9 @@ NODE_CLASS_MAPPINGS = {
"ImageGridComposite3x3": ImageGridComposite3x3,
"ImageConcanate": ImageConcanate,
"ImageBatchTestPattern": ImageBatchTestPattern,
- "ReplaceImagesInBatch": ReplaceImagesInBatch
+ "ReplaceImagesInBatch": ReplaceImagesInBatch,
+ "BatchCropFromMask": BatchCropFromMask,
+ "BatchUncrop": BatchUncrop,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@@ -1326,5 +1495,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageGridComposite3x3": "ImageGridComposite3x3",
"ImageConcanate": "ImageConcanate",
"ImageBatchTestPattern": "ImageBatchTestPattern",
- "ReplaceImagesInBatch": "ReplaceImagesInBatch"
+ "ReplaceImagesInBatch": "ReplaceImagesInBatch",
+ "BatchCropFromMask": "BatchCropFromMask",
+ "BatchUncrop": "BatchUncrop",
}

utility.py (new file)

@@ -0,0 +1,39 @@
import torch
import numpy as np
from PIL import Image
from typing import Union, List
# Utility functions from mtb nodes: https://github.com/melMass/comfy_mtb
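# ComfyUI convention: image tensors are float32 in [0, 1], shaped
# (batch, height, width, channels).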
def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
if isinstance(image, list):
return torch.cat([pil2tensor(img) for img in image], dim=0)
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
if isinstance(img_np, list):
return torch.cat([np2tensor(img) for img in img_np], dim=0)
return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
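# tensor2np returns a single uint8 array for one image (3D tensor) and a
# list of arrays for a batch; tensor2pil always returns a list of PIL images.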
def tensor2np(tensor: torch.Tensor):
if len(tensor.shape) == 3: # Single image
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
else: # Batch of images
return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]
def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
batch_count = image.size(0) if len(image.shape) > 3 else 1
if batch_count > 1:
out = []
for i in range(batch_count):
out.extend(tensor2pil(image[i]))
return out
return [
Image.fromarray(
np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
)
]
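
A minimal sketch of the intended round trip with these nodes, assuming `images` and `masks` are ComfyUI IMAGE/MASK batch tensors (the variable names are illustrative, not part of the commit):

crop_node = BatchCropFromMask()
# crop() returns the untouched originals, the fixed-size crops, and the bboxes
originals, crops, bboxes = crop_node.crop(masks, images, bbox_size=256)
# ... process `crops` here, e.g. run them through a detailing pass ...
uncrop_node = BatchUncrop()
# uncrop() pastes the (possibly modified) crops back, feathering the seams
(restored,) = uncrop_node.uncrop(originals, crops, bboxes, border_blending=0.25)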