From 59566706ac9ba80a49f9be96c4b00bf7eaaf663d Mon Sep 17 00:00:00 2001 From: kijai Date: Tue, 31 Oct 2023 17:47:00 +0200 Subject: [PATCH] Add ColorMatch --- nodes.py | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/nodes.py b/nodes.py index 0731dcc..ab81e3e 100644 --- a/nodes.py +++ b/nodes.py @@ -848,7 +848,113 @@ class SomethingToString: else: return return (stringified,) + +from nodes import EmptyLatentImage + +class EmptyLatentImagePresets: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "dimensions": ( + [ '768 x 512', + '960 x 512', + '1024 x 512', + '1536 x 640', + '1536 x 640', + '1344 x 768', + '1216 x 832', + '1152 x 896', + '1024 x 1024', + ], + { + "default": '1024 x 1024' + }), + + "invert": ("BOOLEAN", {"default": False}), + "batch_size": ("INT", { + "default": 1, + "min": 1, + "max": 4096 + }), + }, + } + + RETURN_TYPES = ("LATENT", "INT", "INT") + RETURN_NAMES = ("Latent", "Width", "Height") + FUNCTION = "generate" + CATEGORY = "KJNodes" + + def generate(self, dimensions, invert, batch_size): + result = [x.strip() for x in dimensions.split('x')] + + if invert: + width = int(result[1].split(' ')[0]) + height = int(result[0]) + else: + width = int(result[0]) + height = int(result[1].split(' ')[0]) + latent = EmptyLatentImage().generate(width, height, batch_size)[0] + + return (latent, int(width), int(height),) + +#https://github.com/hahnec/color-matcher/ +from color_matcher import ColorMatcher +from color_matcher.normalizer import Normalizer + +class ColorMatch: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image_ref": ("IMAGE",), + "image_target": ("IMAGE",), + "method": ( + [ + 'mkl', + 'hm', + 'reinhard', + 'mvgd', + 'hm-mvgd-hm', + 'hm-mkl-hm', + ], { + "default": 'mkl' + }), + "use_only_first": ("BOOLEAN", {"default": False}), + }, + } + CATEGORY = "KJNodes" + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image",) + FUNCTION = "colormatch" + + def colormatch(self, image_ref, image_target, use_only_first, method): + cm = ColorMatcher() + batch_size = image_target.shape[0] + out = [] + images_target = image_target.squeeze() + images_ref = image_ref.squeeze() + print(image_ref.shape) + + + for i in range(batch_size): + if use_only_first: + image_ref = images_ref.numpy() + else: + print("MULTIPLE IMAGES") + image_ref = images_ref[i].numpy() + image_target = images_target[i].numpy() + try: + image_result = cm.transfer(src=image_target, ref=image_ref, method=method) + except BaseException as e: + print(f"Error occurred during transfer: {e}") + break + out.append(torch.from_numpy(image_result)) + + return (torch.stack(out, dim=0).to(torch.float32), ) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "ConditioningMultiCombine": ConditioningMultiCombine, @@ -866,6 +972,8 @@ NODE_CLASS_MAPPINGS = { "VRAM_Debug" : VRAM_Debug, "SomethingToString" : SomethingToString, "CrossFadeImages": CrossFadeImages, + "EmptyLatentImagePresets": EmptyLatentImagePresets, + "ColorMatch": ColorMatch, } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -883,4 +991,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "VRAM_Debug" : "VRAM Debug", "CrossFadeImages": "CrossFadeImages", "SomethingToString": "SomethingToString", + "EmptyLatentImagePresets": "EmptyLatentImagePresets", + "ColorMatch": "ColorMatch", } \ No newline at end of file