diff --git a/nodes.py b/nodes.py
index 5e3d68f..5fcb141 100644
--- a/nodes.py
+++ b/nodes.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import scipy.ndimage
 import numpy as np
+from PIL import ImageColor, Image
 from nodes import MAX_RESOLUTION
 
@@ -107,8 +108,83 @@ class PlotNode:
     def plot(self, start, max_frames):
         result = start + max_frames
-        return (result,)
+        return (result,)
+
+def setup_color_to_correct_type(color):
+    if color is None:
+        return None
+    return color if isinstance(color, list) else ImageColor.getcolor(color, "RGB")
+
+class ColorClipToMask:
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "clip"
+    CATEGORY = "KJNodes"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "images": ("IMAGE",),
+                "target": (["TO_BLACK", "TO_WHITE"], {"default": "TO_BLACK"}),
+                "complement": (["TO_BLACK", "TO_WHITE"], {"default": "TO_WHITE"}),
+                "color": ("COLOR",),
+                "threshold": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
+            },
+        }
+
+    def clip(self, images, color, threshold, target, complement, color_a=None, color_b=None):
+        color = setup_color_to_correct_type(color)
+        color_a = setup_color_to_correct_type(color_a)
+        color_b = setup_color_to_correct_type(color_b)
+
+        images = 255. * images.cpu().numpy()
+        images = np.clip(images, 0, 255).astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        images = [np.array(image) for image in images]
+
+        color = np.array(color)  # convert the clip color to a numpy array
+
+        def select_color(selection):
+            match selection:
+                case "TO_BLACK":
+                    return [0, 0, 0]
+                case "TO_WHITE":
+                    return [255, 255, 255]
+                case _:
+                    return None
+
+        complement_color = select_color(complement)
+        target_color = select_color(target)
+
+        new_images = []
+        for image in images:
+            match target:
+                case "NOTHING":
+                    new_image = np.array(image, copy=True)
+                case _:
+                    new_image = np.full_like(image, target_color)
+
+            color_distances = np.linalg.norm(image - color, axis=-1)  # per-pixel distance to the clip color
+            complement_indexes = color_distances <= threshold  # pixels close enough to count as a match
+
+            match complement:
+                case "NOTHING":
+                    new_image[complement_indexes] = image[complement_indexes]
+                case _:
+                    new_image[complement_indexes] = complement_color
+
+            new_images.append(new_image)
+
+        new_images = np.array(new_images).astype(np.float32) / 255.0
+        new_images = torch.from_numpy(new_images)
+        new_images = new_images.permute(0, 3, 1, 2)
+
+        return (new_images,)
+
 class ConditioningMultiCombine:
     @classmethod
     def INPUT_TYPES(s):
@@ -202,10 +278,12 @@ NODE_CLASS_MAPPINGS = {
     "ConditioningMultiCombine": ConditioningMultiCombine,
     "ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine,
     "GrowMaskWithBlur": GrowMaskWithBlur,
+    "ColorClipToMask": ColorClipToMask,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
     "ConditioningMultiCombine": "Conditioning Multi Combine",
     "ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine",
     "GrowMaskWithBlur": "GrowMaskWithBlur",
+    "ColorClipToMask": "ColorClipToMask",
 }
\ No newline at end of file
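
Reviewer note: a minimal standalone sketch of how the new node's clip() method behaves. This is illustrative only and not part of the diff; it assumes the patched module is importable (it pulls MAX_RESOLUTION from ComfyUI's own nodes.py, so it only loads inside a ComfyUI environment), and the import path, tensor layout (B, H, W, C as produced by ComfyUI image loaders), and hex color string are assumptions for the sketch.

import torch
from nodes import ColorClipToMask  # hypothetical import path for this sketch

# one 4x4 test image: left half pure red, right half black
img = torch.zeros(1, 4, 4, 3)
img[:, :, :2, 0] = 1.0

node = ColorClipToMask()
(mask,) = node.clip(
    images=img,
    color="#ff0000",        # parsed to (255, 0, 0) by setup_color_to_correct_type
    threshold=10,           # pixels within this distance of the color count as matches
    target="TO_BLACK",      # non-matching pixels become black
    complement="TO_WHITE",  # matching pixels become white
)
print(mask.shape)  # torch.Size([1, 3, 4, 4]) after the final permute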