mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-03-16 14:17:03 +08:00
Add batch uncrop/crop
Modfied nodes from mtb nodes
This commit is contained in:
parent
98bff811f3
commit
dbd2ea2406
207
nodes.py
207
nodes.py
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||||||
import torchvision.utils as vutils
|
import torchvision.utils as vutils
|
||||||
import scipy.ndimage
|
import scipy.ndimage
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import ImageColor, Image, ImageDraw, ImageFont
|
from PIL import ImageFilter, Image, ImageDraw, ImageFont
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@ -64,7 +64,7 @@ class CreateFluidMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
RETURN_TYPES = ("IMAGE", "MASK")
|
||||||
FUNCTION = "createfluidmask"
|
FUNCTION = "createfluidmask"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/generate"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -141,7 +141,7 @@ class CreateAudioMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "createaudiomask"
|
FUNCTION = "createaudiomask"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/generate"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -194,7 +194,7 @@ class CreateGradientMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("MASK",)
|
RETURN_TYPES = ("MASK",)
|
||||||
FUNCTION = "createmask"
|
FUNCTION = "createmask"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/generate"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -229,7 +229,7 @@ class CreateFadeMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("MASK",)
|
RETURN_TYPES = ("MASK",)
|
||||||
FUNCTION = "createfademask"
|
FUNCTION = "createfademask"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/generate"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -454,7 +454,7 @@ class CreateTextMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK",)
|
RETURN_TYPES = ("IMAGE", "MASK",)
|
||||||
FUNCTION = "createtextmask"
|
FUNCTION = "createtextmask"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/generate"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -534,7 +534,7 @@ class GrowMaskWithBlur:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking"
|
||||||
|
|
||||||
RETURN_TYPES = ("MASK", "MASK",)
|
RETURN_TYPES = ("MASK", "MASK",)
|
||||||
RETURN_NAMES = ("mask", "mask_inverted",)
|
RETURN_NAMES = ("mask", "mask_inverted",)
|
||||||
@ -602,7 +602,7 @@ class ColorToMask:
|
|||||||
|
|
||||||
RETURN_TYPES = ("MASK",)
|
RETURN_TYPES = ("MASK",)
|
||||||
FUNCTION = "clip"
|
FUNCTION = "clip"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -659,7 +659,7 @@ class ConditioningMultiCombine:
|
|||||||
RETURN_TYPES = ("CONDITIONING", "INT")
|
RETURN_TYPES = ("CONDITIONING", "INT")
|
||||||
RETURN_NAMES = ("combined", "inputcount")
|
RETURN_NAMES = ("combined", "inputcount")
|
||||||
FUNCTION = "combine"
|
FUNCTION = "combine"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/conditioning"
|
||||||
|
|
||||||
def combine(self, inputcount, **kwargs):
|
def combine(self, inputcount, **kwargs):
|
||||||
cond_combine_node = nodes.ConditioningCombine()
|
cond_combine_node = nodes.ConditioningCombine()
|
||||||
@ -697,7 +697,7 @@ class ConditioningSetMaskAndCombine:
|
|||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||||
FUNCTION = "append"
|
FUNCTION = "append"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/conditioning"
|
||||||
|
|
||||||
def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
|
def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
|
||||||
c = []
|
c = []
|
||||||
@ -743,7 +743,7 @@ class ConditioningSetMaskAndCombine3:
|
|||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||||
FUNCTION = "append"
|
FUNCTION = "append"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/conditioning"
|
||||||
|
|
||||||
def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
|
def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
|
||||||
c = []
|
c = []
|
||||||
@ -799,7 +799,7 @@ class ConditioningSetMaskAndCombine4:
|
|||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||||
FUNCTION = "append"
|
FUNCTION = "append"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/conditioning"
|
||||||
|
|
||||||
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
|
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
|
||||||
c = []
|
c = []
|
||||||
@ -865,7 +865,7 @@ class ConditioningSetMaskAndCombine5:
|
|||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||||
FUNCTION = "append"
|
FUNCTION = "append"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking/conditioning"
|
||||||
|
|
||||||
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
|
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
|
||||||
c = []
|
c = []
|
||||||
@ -1034,7 +1034,7 @@ class ColorMatch:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes/masking"
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
RETURN_NAMES = ("image",)
|
RETURN_NAMES = ("image",)
|
||||||
@ -1230,8 +1230,6 @@ class ImageBatchTestPattern:
|
|||||||
FUNCTION = "generatetestpattern"
|
FUNCTION = "generatetestpattern"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generatetestpattern(self, batch_size, start_from, width, height):
|
def generatetestpattern(self, batch_size, start_from, width, height):
|
||||||
out = []
|
out = []
|
||||||
# Generate the sequential numbers for each image
|
# Generate the sequential numbers for each image
|
||||||
@ -1271,6 +1269,175 @@ class ImageBatchTestPattern:
|
|||||||
|
|
||||||
return (torch.cat(out, dim=0),)
|
return (torch.cat(out, dim=0),)
|
||||||
|
|
||||||
|
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||||
|
|
||||||
|
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
||||||
|
|
||||||
|
class BatchCropFromMask:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"original_images": ("IMAGE",),
|
||||||
|
"masks": ("MASK",),
|
||||||
|
"bbox_size": ("INT", {"default": 256, "min": 64, "max": 1024, "step": 8}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (
|
||||||
|
"IMAGE",
|
||||||
|
"IMAGE",
|
||||||
|
"BBOX",
|
||||||
|
)
|
||||||
|
RETURN_NAMES = (
|
||||||
|
"original_images",
|
||||||
|
"cropped_images",
|
||||||
|
"bboxes",
|
||||||
|
)
|
||||||
|
FUNCTION = "crop"
|
||||||
|
CATEGORY = "KJNodes/masking"
|
||||||
|
|
||||||
|
def crop(self, masks, original_images, bbox_size):
|
||||||
|
bounding_boxes = []
|
||||||
|
cropped_images = []
|
||||||
|
|
||||||
|
for mask, img in zip(masks, original_images):
|
||||||
|
_mask = tensor2pil(mask)[0]
|
||||||
|
|
||||||
|
# Calculate bounding box coordinates
|
||||||
|
non_zero_indices = np.nonzero(np.array(_mask))
|
||||||
|
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
|
||||||
|
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
|
||||||
|
|
||||||
|
# Calculate center of bounding box
|
||||||
|
center_x = (max_x + min_x) // 2
|
||||||
|
center_y = (max_y + min_y) // 2
|
||||||
|
|
||||||
|
# Create fixed-size bounding box around center
|
||||||
|
half_box_size = bbox_size // 2
|
||||||
|
min_x = center_x - half_box_size
|
||||||
|
max_x = center_x + half_box_size
|
||||||
|
min_y = center_y - half_box_size
|
||||||
|
max_y = center_y + half_box_size
|
||||||
|
|
||||||
|
# Check if the bounding box dimensions go outside the image dimensions
|
||||||
|
if min_x < 0:
|
||||||
|
max_x -= min_x
|
||||||
|
min_x = 0
|
||||||
|
if max_x > img.shape[1]:
|
||||||
|
min_x -= max_x - img.shape[1]
|
||||||
|
max_x = img.shape[1]
|
||||||
|
if min_y < 0:
|
||||||
|
max_y -= min_y
|
||||||
|
min_y = 0
|
||||||
|
if max_y > img.shape[0]:
|
||||||
|
min_y -= max_y - img.shape[0]
|
||||||
|
max_y = img.shape[0]
|
||||||
|
|
||||||
|
# Append bounding box coordinates
|
||||||
|
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
|
||||||
|
|
||||||
|
# Crop the image from the bounding box
|
||||||
|
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
||||||
|
cropped_images.append(cropped_img)
|
||||||
|
cropped_out = torch.stack(cropped_images, dim=0)
|
||||||
|
|
||||||
|
return (original_images, cropped_out, bounding_boxes,)
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_to_region(bbox, target_size=None):
|
||||||
|
bbox = bbox_check(bbox, target_size)
|
||||||
|
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
|
||||||
|
|
||||||
|
def bbox_check(bbox, target_size=None):
|
||||||
|
if not target_size:
|
||||||
|
return bbox
|
||||||
|
|
||||||
|
new_bbox = (
|
||||||
|
bbox[0],
|
||||||
|
bbox[1],
|
||||||
|
min(target_size[0] - bbox[0], bbox[2]),
|
||||||
|
min(target_size[1] - bbox[1], bbox[3]),
|
||||||
|
)
|
||||||
|
return new_bbox
|
||||||
|
|
||||||
|
class BatchUncrop:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"original_images": ("IMAGE",),
|
||||||
|
"cropped_images": ("IMAGE",),
|
||||||
|
"bboxes": ("BBOX",),
|
||||||
|
"border_blending": (
|
||||||
|
"FLOAT",
|
||||||
|
{"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "uncrop"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes/masking"
|
||||||
|
|
||||||
|
def uncrop(self, original_images, cropped_images, bboxes, border_blending):
|
||||||
|
def inset_border(image, border_width=20, border_color=(0)):
|
||||||
|
width, height = image.size
|
||||||
|
bordered_image = Image.new(image.mode, (width, height), border_color)
|
||||||
|
bordered_image.paste(image, (0, 0))
|
||||||
|
draw = ImageDraw.Draw(bordered_image)
|
||||||
|
draw.rectangle(
|
||||||
|
(0, 0, width - 1, height - 1), outline=border_color, width=border_width
|
||||||
|
)
|
||||||
|
return bordered_image
|
||||||
|
|
||||||
|
if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
|
||||||
|
raise ValueError("The number of images, crop_images, and bboxes should be the same")
|
||||||
|
|
||||||
|
input_images = tensor2pil(original_images)
|
||||||
|
crop_imgs = tensor2pil(cropped_images)
|
||||||
|
out_images = []
|
||||||
|
for i in range(len(input_images)):
|
||||||
|
img = input_images[i]
|
||||||
|
crop = crop_imgs[i]
|
||||||
|
bbox = bboxes[i]
|
||||||
|
|
||||||
|
# uncrop the image based on the bounding box
|
||||||
|
bb_x, bb_y, bb_width, bb_height = bbox
|
||||||
|
|
||||||
|
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
|
||||||
|
|
||||||
|
crop_img = crop.convert("RGB")
|
||||||
|
|
||||||
|
if border_blending > 1.0:
|
||||||
|
border_blending = 1.0
|
||||||
|
elif border_blending < 0.0:
|
||||||
|
border_blending = 0.0
|
||||||
|
|
||||||
|
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
|
||||||
|
|
||||||
|
blend = img.convert("RGBA")
|
||||||
|
mask = Image.new("L", img.size, 0)
|
||||||
|
|
||||||
|
mask_block = Image.new("L", (bb_width, bb_height), 255)
|
||||||
|
mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
|
||||||
|
|
||||||
|
mask.paste(mask_block, paste_region)
|
||||||
|
blend.paste(crop_img, paste_region)
|
||||||
|
|
||||||
|
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
|
||||||
|
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
|
||||||
|
|
||||||
|
blend.putalpha(mask)
|
||||||
|
img = Image.alpha_composite(img.convert("RGBA"), blend)
|
||||||
|
out_images.append(img.convert("RGB"))
|
||||||
|
|
||||||
|
return (pil2tensor(out_images),)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
"FloatConstant": FloatConstant,
|
"FloatConstant": FloatConstant,
|
||||||
@ -1298,7 +1465,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ImageGridComposite3x3": ImageGridComposite3x3,
|
"ImageGridComposite3x3": ImageGridComposite3x3,
|
||||||
"ImageConcanate": ImageConcanate,
|
"ImageConcanate": ImageConcanate,
|
||||||
"ImageBatchTestPattern": ImageBatchTestPattern,
|
"ImageBatchTestPattern": ImageBatchTestPattern,
|
||||||
"ReplaceImagesInBatch": ReplaceImagesInBatch
|
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
||||||
|
"BatchCropFromMask": BatchCropFromMask,
|
||||||
|
"BatchUncrop": BatchUncrop,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -1326,5 +1495,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImageGridComposite3x3": "ImageGridComposite3x3",
|
"ImageGridComposite3x3": "ImageGridComposite3x3",
|
||||||
"ImageConcanate": "ImageConcanate",
|
"ImageConcanate": "ImageConcanate",
|
||||||
"ImageBatchTestPattern": "ImageBatchTestPattern",
|
"ImageBatchTestPattern": "ImageBatchTestPattern",
|
||||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch"
|
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
||||||
|
"BatchCropFromMask": "BatchCropFromMask",
|
||||||
|
"BatchUncrop": "BatchUncrop",
|
||||||
}
|
}
|
||||||
39
utility.py
Normal file
39
utility.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
# Utility functions from mtb nodes: https://github.com/melMass/comfy_mtb
|
||||||
|
def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
|
||||||
|
if isinstance(image, list):
|
||||||
|
return torch.cat([pil2tensor(img) for img in image], dim=0)
|
||||||
|
|
||||||
|
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
|
||||||
|
if isinstance(img_np, list):
|
||||||
|
return torch.cat([np2tensor(img) for img in img_np], dim=0)
|
||||||
|
|
||||||
|
return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2np(tensor: torch.Tensor):
|
||||||
|
if len(tensor.shape) == 3: # Single image
|
||||||
|
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
|
||||||
|
else: # Batch of images
|
||||||
|
return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]
|
||||||
|
|
||||||
|
def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
|
||||||
|
batch_count = image.size(0) if len(image.shape) > 3 else 1
|
||||||
|
if batch_count > 1:
|
||||||
|
out = []
|
||||||
|
for i in range(batch_count):
|
||||||
|
out.extend(tensor2pil(image[i]))
|
||||||
|
return out
|
||||||
|
|
||||||
|
return [
|
||||||
|
Image.fromarray(
|
||||||
|
np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
|
||||||
|
)
|
||||||
|
]
|
||||||
Loading…
x
Reference in New Issue
Block a user