mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add batch uncrop/crop
Modfied nodes from mtb nodes
This commit is contained in:
parent
98bff811f3
commit
dbd2ea2406
209
nodes.py
209
nodes.py
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||
import torchvision.utils as vutils
|
||||
import scipy.ndimage
|
||||
import numpy as np
|
||||
from PIL import ImageColor, Image, ImageDraw, ImageFont
|
||||
from PIL import ImageFilter, Image, ImageDraw, ImageFont
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import json
|
||||
import re
|
||||
@ -64,7 +64,7 @@ class CreateFluidMask:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "createfluidmask"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -141,7 +141,7 @@ class CreateAudioMask:
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "createaudiomask"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -194,7 +194,7 @@ class CreateGradientMask:
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "createmask"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -229,7 +229,7 @@ class CreateFadeMask:
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "createfademask"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -454,7 +454,7 @@ class CreateTextMask:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK",)
|
||||
FUNCTION = "createtextmask"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -534,7 +534,7 @@ class GrowMaskWithBlur:
|
||||
},
|
||||
}
|
||||
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
RETURN_TYPES = ("MASK", "MASK",)
|
||||
RETURN_NAMES = ("mask", "mask_inverted",)
|
||||
@ -602,7 +602,7 @@ class ColorToMask:
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "clip"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -659,7 +659,7 @@ class ConditioningMultiCombine:
|
||||
RETURN_TYPES = ("CONDITIONING", "INT")
|
||||
RETURN_NAMES = ("combined", "inputcount")
|
||||
FUNCTION = "combine"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/conditioning"
|
||||
|
||||
def combine(self, inputcount, **kwargs):
|
||||
cond_combine_node = nodes.ConditioningCombine()
|
||||
@ -697,7 +697,7 @@ class ConditioningSetMaskAndCombine:
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||
FUNCTION = "append"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/conditioning"
|
||||
|
||||
def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
|
||||
c = []
|
||||
@ -743,7 +743,7 @@ class ConditioningSetMaskAndCombine3:
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||
FUNCTION = "append"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/conditioning"
|
||||
|
||||
def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
|
||||
c = []
|
||||
@ -799,7 +799,7 @@ class ConditioningSetMaskAndCombine4:
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||
FUNCTION = "append"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/conditioning"
|
||||
|
||||
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
|
||||
c = []
|
||||
@ -865,7 +865,7 @@ class ConditioningSetMaskAndCombine5:
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
|
||||
RETURN_NAMES = ("combined_positive", "combined_negative",)
|
||||
FUNCTION = "append"
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking/conditioning"
|
||||
|
||||
def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
|
||||
c = []
|
||||
@ -1034,7 +1034,7 @@ class ColorMatch:
|
||||
},
|
||||
}
|
||||
|
||||
CATEGORY = "KJNodes"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("image",)
|
||||
@ -1230,8 +1230,6 @@ class ImageBatchTestPattern:
|
||||
FUNCTION = "generatetestpattern"
|
||||
CATEGORY = "KJNodes"
|
||||
|
||||
|
||||
|
||||
def generatetestpattern(self, batch_size, start_from, width, height):
|
||||
out = []
|
||||
# Generate the sequential numbers for each image
|
||||
@ -1270,7 +1268,176 @@ class ImageBatchTestPattern:
|
||||
out.append(image)
|
||||
|
||||
return (torch.cat(out, dim=0),)
|
||||
|
||||
|
||||
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||
|
||||
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
||||
|
||||
class BatchCropFromMask:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"original_images": ("IMAGE",),
|
||||
"masks": ("MASK",),
|
||||
"bbox_size": ("INT", {"default": 256, "min": 64, "max": 1024, "step": 8}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (
|
||||
"IMAGE",
|
||||
"IMAGE",
|
||||
"BBOX",
|
||||
)
|
||||
RETURN_NAMES = (
|
||||
"original_images",
|
||||
"cropped_images",
|
||||
"bboxes",
|
||||
)
|
||||
FUNCTION = "crop"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def crop(self, masks, original_images, bbox_size):
|
||||
bounding_boxes = []
|
||||
cropped_images = []
|
||||
|
||||
for mask, img in zip(masks, original_images):
|
||||
_mask = tensor2pil(mask)[0]
|
||||
|
||||
# Calculate bounding box coordinates
|
||||
non_zero_indices = np.nonzero(np.array(_mask))
|
||||
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
|
||||
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
|
||||
|
||||
# Calculate center of bounding box
|
||||
center_x = (max_x + min_x) // 2
|
||||
center_y = (max_y + min_y) // 2
|
||||
|
||||
# Create fixed-size bounding box around center
|
||||
half_box_size = bbox_size // 2
|
||||
min_x = center_x - half_box_size
|
||||
max_x = center_x + half_box_size
|
||||
min_y = center_y - half_box_size
|
||||
max_y = center_y + half_box_size
|
||||
|
||||
# Check if the bounding box dimensions go outside the image dimensions
|
||||
if min_x < 0:
|
||||
max_x -= min_x
|
||||
min_x = 0
|
||||
if max_x > img.shape[1]:
|
||||
min_x -= max_x - img.shape[1]
|
||||
max_x = img.shape[1]
|
||||
if min_y < 0:
|
||||
max_y -= min_y
|
||||
min_y = 0
|
||||
if max_y > img.shape[0]:
|
||||
min_y -= max_y - img.shape[0]
|
||||
max_y = img.shape[0]
|
||||
|
||||
# Append bounding box coordinates
|
||||
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
|
||||
|
||||
# Crop the image from the bounding box
|
||||
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
||||
cropped_images.append(cropped_img)
|
||||
cropped_out = torch.stack(cropped_images, dim=0)
|
||||
|
||||
return (original_images, cropped_out, bounding_boxes,)
|
||||
|
||||
|
||||
def bbox_to_region(bbox, target_size=None):
|
||||
bbox = bbox_check(bbox, target_size)
|
||||
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
|
||||
|
||||
def bbox_check(bbox, target_size=None):
|
||||
if not target_size:
|
||||
return bbox
|
||||
|
||||
new_bbox = (
|
||||
bbox[0],
|
||||
bbox[1],
|
||||
min(target_size[0] - bbox[0], bbox[2]),
|
||||
min(target_size[1] - bbox[1], bbox[3]),
|
||||
)
|
||||
return new_bbox
|
||||
|
||||
class BatchUncrop:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"original_images": ("IMAGE",),
|
||||
"cropped_images": ("IMAGE",),
|
||||
"bboxes": ("BBOX",),
|
||||
"border_blending": (
|
||||
"FLOAT",
|
||||
{"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "uncrop"
|
||||
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def uncrop(self, original_images, cropped_images, bboxes, border_blending):
|
||||
def inset_border(image, border_width=20, border_color=(0)):
|
||||
width, height = image.size
|
||||
bordered_image = Image.new(image.mode, (width, height), border_color)
|
||||
bordered_image.paste(image, (0, 0))
|
||||
draw = ImageDraw.Draw(bordered_image)
|
||||
draw.rectangle(
|
||||
(0, 0, width - 1, height - 1), outline=border_color, width=border_width
|
||||
)
|
||||
return bordered_image
|
||||
|
||||
if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
|
||||
raise ValueError("The number of images, crop_images, and bboxes should be the same")
|
||||
|
||||
input_images = tensor2pil(original_images)
|
||||
crop_imgs = tensor2pil(cropped_images)
|
||||
out_images = []
|
||||
for i in range(len(input_images)):
|
||||
img = input_images[i]
|
||||
crop = crop_imgs[i]
|
||||
bbox = bboxes[i]
|
||||
|
||||
# uncrop the image based on the bounding box
|
||||
bb_x, bb_y, bb_width, bb_height = bbox
|
||||
|
||||
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
|
||||
|
||||
crop_img = crop.convert("RGB")
|
||||
|
||||
if border_blending > 1.0:
|
||||
border_blending = 1.0
|
||||
elif border_blending < 0.0:
|
||||
border_blending = 0.0
|
||||
|
||||
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
|
||||
|
||||
blend = img.convert("RGBA")
|
||||
mask = Image.new("L", img.size, 0)
|
||||
|
||||
mask_block = Image.new("L", (bb_width, bb_height), 255)
|
||||
mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
|
||||
|
||||
mask.paste(mask_block, paste_region)
|
||||
blend.paste(crop_img, paste_region)
|
||||
|
||||
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
|
||||
|
||||
blend.putalpha(mask)
|
||||
img = Image.alpha_composite(img.convert("RGBA"), blend)
|
||||
out_images.append(img.convert("RGB"))
|
||||
|
||||
return (pil2tensor(out_images),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"INTConstant": INTConstant,
|
||||
"FloatConstant": FloatConstant,
|
||||
@ -1298,7 +1465,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageGridComposite3x3": ImageGridComposite3x3,
|
||||
"ImageConcanate": ImageConcanate,
|
||||
"ImageBatchTestPattern": ImageBatchTestPattern,
|
||||
"ReplaceImagesInBatch": ReplaceImagesInBatch
|
||||
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
||||
"BatchCropFromMask": BatchCropFromMask,
|
||||
"BatchUncrop": BatchUncrop,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"INTConstant": "INT Constant",
|
||||
@ -1326,5 +1495,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageGridComposite3x3": "ImageGridComposite3x3",
|
||||
"ImageConcanate": "ImageConcanate",
|
||||
"ImageBatchTestPattern": "ImageBatchTestPattern",
|
||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch"
|
||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
||||
"BatchCropFromMask": "BatchCropFromMask",
|
||||
"BatchUncrop": "BatchUncrop",
|
||||
}
|
||||
39
utility.py
Normal file
39
utility.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Union, List
|
||||
|
||||
# Utility functions from mtb nodes: https://github.com/melMass/comfy_mtb
|
||||
def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
|
||||
if isinstance(image, list):
|
||||
return torch.cat([pil2tensor(img) for img in image], dim=0)
|
||||
|
||||
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
||||
|
||||
|
||||
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
|
||||
if isinstance(img_np, list):
|
||||
return torch.cat([np2tensor(img) for img in img_np], dim=0)
|
||||
|
||||
return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
|
||||
|
||||
|
||||
def tensor2np(tensor: torch.Tensor):
|
||||
if len(tensor.shape) == 3: # Single image
|
||||
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
|
||||
else: # Batch of images
|
||||
return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]
|
||||
|
||||
def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
|
||||
batch_count = image.size(0) if len(image.shape) > 3 else 1
|
||||
if batch_count > 1:
|
||||
out = []
|
||||
for i in range(batch_count):
|
||||
out.extend(tensor2pil(image[i]))
|
||||
return out
|
||||
|
||||
return [
|
||||
Image.fromarray(
|
||||
np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
|
||||
)
|
||||
]
|
||||
Loading…
x
Reference in New Issue
Block a user