fix ColorClipMask

works with batches now
kijai 2023-10-02 15:21:14 +03:00
parent b72602c589
commit 394fc44037


@@ -4,6 +4,7 @@ import torch.nn.functional as F
 import scipy.ndimage
 import numpy as np
 from PIL import ImageColor, Image
+from colorama import Fore, Back, Style
 from nodes import MAX_RESOLUTION
@@ -109,13 +110,8 @@ class PlotNode:
     def plot(self, start, max_frames):
         result = start + max_frames
         return (result,)
-def setup_color_to_correct_type(color):
-    if color is None:
-        return None
-    return color if isinstance(color, list) else ImageColor.getcolor(color, "RGB")
-class ColorClipToMask:
+class ColorToMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "clip"
@@ -126,64 +122,40 @@ class ColorClipToMask:
         return {
             "required": {
                 "images": ("IMAGE",),
-                "target": (["TO_BLACK","TO_WHITE"],{"default": "TO_BLACK"}),
-                "complement": (["TO_BLACK","TO_WHITE"],{"default": "TO_WHITE"}),
-                "color": ("COLOR",),
+                "invert": ("BOOLEAN", {"default": False}),
+                "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+                "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
+                "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
             },
         }
-    # def color_clip(self, image, color, threshold, target, complement):
-    #     image = self.clip(image, color, threshold, target, complement)
-    #     return (image,)
-    def clip(self, images, color, threshold, target, complement, color_a=None, color_b=None):
-        color = setup_color_to_correct_type(color)
-        color_a = setup_color_to_correct_type(color_a)
-        color_b = setup_color_to_correct_type(color_b)
+    def clip(self, images, red, green, blue, threshold, invert):
+        color = np.array([red, green, blue])
         images = 255. * images.cpu().numpy()
         images = np.clip(images, 0, 255).astype(np.uint8)
         images = [Image.fromarray(image) for image in images]
         images = [np.array(image) for image in images]
-        color = np.array(color) # Convert color to a numpy array
-        def select_color(selection):
-            match selection:
-                case "TO_BLACK":
-                    return [0, 0, 0]
-                case "TO_WHITE":
-                    return [255, 255, 255]
-                case "_":
-                    return None
-        complement_color = select_color(complement)
-        target_color = select_color(target)
+        black = [0, 0, 0]
+        white = [255, 255, 255]
+        if invert:
+            black, white = white, black
         new_images = []
         for image in images:
-            match target:
-                case "NOTHING":
-                    new_image = np.array(image, copy=True)
-                case _:
-                    new_image = np.full_like(image, target_color)
+            new_image = np.full_like(image, black)
             color_distances = np.linalg.norm(image - color, axis=-1)
             complement_indexes = color_distances <= threshold
-            match complement:
-                case "NOTHING":
-                    new_image[complement_indexes] = image[complement_indexes]
-                case _:
-                    new_image[complement_indexes] = complement_color
+            new_image[complement_indexes] = white
             new_images.append(new_image)
         new_images = np.array(new_images).astype(np.float32) / 255.0
-        new_images = torch.from_numpy(new_images)
-        new_images = new_images.permute(0, 3, 1, 2)
-        return (new_images,)
+        new_images = torch.from_numpy(new_images).permute(3, 0, 1, 2)
+        return new_images
 class ConditioningMultiCombine:
     @classmethod
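
For reference, a minimal standalone sketch of the new masking logic (assumptions: plain numpy + torch only, the PIL round-trip from the diff is omitted, and the helper name color_to_mask is hypothetical). The batch fix is in the last two lines: permuting to (C, B, H, W) and returning the bare tensor means that indexing the result with [0], as ComfyUI does per output slot, yields a (B, H, W) tensor, i.e. one mask per input image.

import numpy as np
import torch

def color_to_mask(images, red, green, blue, threshold, invert):
    # images: IMAGE batch, float32 in [0, 1], shape (B, H, W, 3)
    color = np.array([red, green, blue])
    frames = np.clip(255. * images.cpu().numpy(), 0, 255).astype(np.uint8)
    black, white = [0, 0, 0], [255, 255, 255]
    if invert:
        black, white = white, black  # invert swaps background and match colors
    masks = []
    for image in frames:
        out = np.full_like(image, black)               # background
        dist = np.linalg.norm(image - color, axis=-1)  # per-pixel distance to target color
        out[dist <= threshold] = white                 # pixels within tolerance
        masks.append(out)
    result = torch.from_numpy(np.array(masks).astype(np.float32) / 255.0)
    # (B, H, W, C) -> (C, B, H, W): result[0] is then a (B, H, W) mask batch
    return result.permute(3, 0, 1, 2)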
@@ -278,12 +250,12 @@ NODE_CLASS_MAPPINGS = {
     "ConditioningMultiCombine": ConditioningMultiCombine,
     "ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine,
     "GrowMaskWithBlur": GrowMaskWithBlur,
-    "ColorClipToMask": ColorClipToMask,
+    "ColorToMask": ColorToMask,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
     "ConditioningMultiCombine": "Conditioning Multi Combine",
     "ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine",
     "GrowMaskWithBlur": "GrowMaskWithBlur",
-    "ColorClipToMask": "ColorClipToMask",
+    "ColorToMask": "ColorToMask",
 }
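
A quick batch sanity check, run in the module's own context (hedged sketch; in a real graph the ComfyUI executor supplies these inputs from INPUT_TYPES):

node = ColorToMask()
batch = torch.rand(4, 64, 64, 3)  # four frames in ComfyUI IMAGE layout
result = node.clip(batch, red=255, green=0, blue=0, threshold=10, invert=False)
masks = result[0]   # output slot 0, exactly as ComfyUI would index it
print(masks.shape)  # torch.Size([4, 64, 64]) -- one mask per frame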