fix ColorClipMask

works with batches now
kijai 2023-10-02 15:21:14 +03:00
parent b72602c589
commit 394fc44037


@@ -4,6 +4,7 @@ import torch.nn.functional as F
 import scipy.ndimage
 import numpy as np
 from PIL import ImageColor, Image
+from colorama import Fore, Back, Style
 from nodes import MAX_RESOLUTION
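
The new colorama import isn't exercised in the hunks shown here; it is presumably for colored console output elsewhere in the file. Typical usage, for context (illustrative, not from the commit):

    from colorama import Fore, Style, init

    init()  # enable ANSI colors on Windows terminals
    print(Fore.CYAN + "ColorToMask: processing batch" + Style.RESET_ALL)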
@@ -109,13 +110,8 @@ class PlotNode:
     def plot(self, start, max_frames):
         result = start + max_frames
         return (result,)
 
-def setup_color_to_correct_type(color):
-    if color is None:
-        return None
-    return color if isinstance(color, list) else ImageColor.getcolor(color, "RGB")
 
-class ColorClipToMask:
+class ColorToMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "clip"
@@ -126,64 +122,40 @@ class ColorClipToMask:
         return {
             "required": {
                 "images": ("IMAGE",),
-                "target": (["TO_BLACK","TO_WHITE"],{"default": "TO_BLACK"}),
-                "complement": (["TO_BLACK","TO_WHITE"],{"default": "TO_WHITE"}),
-                "color": ("COLOR",),
+                "invert": ("BOOLEAN", {"default": False}),
+                "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+                "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+                "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
             },
         }
 
-    # def color_clip(self, image, color, threshold, target, complement):
-    #     image = self.clip(image, color, threshold, target, complement)
-    #     return (image,)
-
-    def clip(self, images, color, threshold, target, complement, color_a=None, color_b=None):
-        color = setup_color_to_correct_type(color)
-        color_a = setup_color_to_correct_type(color_a)
-        color_b = setup_color_to_correct_type(color_b)
+    def clip(self, images, red, green, blue, threshold, invert):
+        color = np.array([red, green, blue])
         images = 255. * images.cpu().numpy()
         images = np.clip(images, 0, 255).astype(np.uint8)
         images = [Image.fromarray(image) for image in images]
         images = [np.array(image) for image in images]
-        color = np.array(color)  # Convert color to a numpy array
 
-        def select_color(selection):
-            match selection:
-                case "TO_BLACK":
-                    return [0, 0, 0]
-                case "TO_WHITE":
-                    return [255, 255, 255]
-                case "_":
-                    return None
-
-        complement_color = select_color(complement)
-        target_color = select_color(target)
+        black = [0, 0, 0]
+        white = [255, 255, 255]
+        if invert:
+            black, white = white, black
 
         new_images = []
         for image in images:
-            match target:
-                case "NOTHING":
-                    new_image = np.array(image, copy=True)
-                case _:
-                    new_image = np.full_like(image, target_color)
+            new_image = np.full_like(image, black)
             color_distances = np.linalg.norm(image - color, axis=-1)
             complement_indexes = color_distances <= threshold
-            match complement:
-                case "NOTHING":
-                    new_image[complement_indexes] = image[complement_indexes]
-                case _:
-                    new_image[complement_indexes] = complement_color
+            new_image[complement_indexes] = white
             new_images.append(new_image)
 
         new_images = np.array(new_images).astype(np.float32) / 255.0
-        new_images = torch.from_numpy(new_images)
-        new_images = new_images.permute(0, 3, 1, 2)
-        return (new_images,)
+        new_images = torch.from_numpy(new_images).permute(3, 0, 1, 2)
+        return new_images
 
 class ConditioningMultiCombine:
     @classmethod
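
Since the commit message is about batch support: the rewritten clip builds one mask per image in the incoming batch by thresholding the per-pixel Euclidean distance to the target color. A minimal standalone sketch of that logic outside ComfyUI (function name, shapes, and the demo values are assumptions for illustration):

    import numpy as np
    import torch

    def color_to_mask(images, red, green, blue, threshold, invert):
        # images: ComfyUI-style IMAGE batch, float tensor in [0, 1], shape (B, H, W, 3)
        color = np.array([red, green, blue])
        arr = np.clip(255.0 * images.cpu().numpy(), 0, 255).astype(np.uint8)
        black, white = [0, 0, 0], [255, 255, 255]
        if invert:
            black, white = white, black
        masks = []
        for image in arr:  # one pass per image in the batch
            mask = np.full_like(image, black)
            # distance of every pixel from the target color
            distances = np.linalg.norm(image.astype(np.float32) - color, axis=-1)
            mask[distances <= threshold] = white  # close-enough pixels become foreground
            masks.append(mask)
        masks = np.array(masks).astype(np.float32) / 255.0
        return torch.from_numpy(masks).permute(3, 0, 1, 2)  # (C, B, H, W), as in the commit

    # smoke test on a two-image batch with a single red pixel
    batch = torch.zeros(2, 4, 4, 3)
    batch[0, 1, 1] = torch.tensor([1.0, 0.0, 0.0])
    print(color_to_mask(batch, 255, 0, 0, threshold=10, invert=False).shape)
    # torch.Size([3, 2, 4, 4])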
@@ -278,12 +250,12 @@ NODE_CLASS_MAPPINGS = {
     "ConditioningMultiCombine": ConditioningMultiCombine,
     "ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine,
     "GrowMaskWithBlur": GrowMaskWithBlur,
-    "ColorClipToMask": ColorClipToMask,
+    "ColorToMask": ColorToMask,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
     "ConditioningMultiCombine": "Conditioning Multi Combine",
     "ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine",
     "GrowMaskWithBlur": "GrowMaskWithBlur",
-    "ColorClipToMask": "ColorClipToMask",
+    "ColorToMask": "ColorToMask",
 }
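
Renaming the registry keys also changes the node's internal ID: workflows saved with the old "ColorClipToMask" key will no longer resolve and need the node re-added. ComfyUI reads these two dicts when it imports a custom node package, roughly like this (a sketch; the lookup code is illustrative):

    cls = NODE_CLASS_MAPPINGS["ColorToMask"]            # class used to instantiate the node
    label = NODE_DISPLAY_NAME_MAPPINGS["ColorToMask"]   # name shown in the UI
    node = cls()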