Add ColorMatch

This commit is contained in:
kijai 2023-10-31 17:47:00 +02:00
parent c2edc00b37
commit 59566706ac

110
nodes.py
View File

@ -848,7 +848,113 @@ class SomethingToString:
else:
return
return (stringified,)
from nodes import EmptyLatentImage
class EmptyLatentImagePresets:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dimensions": (
[ '768 x 512',
'960 x 512',
'1024 x 512',
'1536 x 640',
'1536 x 640',
'1344 x 768',
'1216 x 832',
'1152 x 896',
'1024 x 1024',
],
{
"default": '1024 x 1024'
}),
"invert": ("BOOLEAN", {"default": False}),
"batch_size": ("INT", {
"default": 1,
"min": 1,
"max": 4096
}),
},
}
RETURN_TYPES = ("LATENT", "INT", "INT")
RETURN_NAMES = ("Latent", "Width", "Height")
FUNCTION = "generate"
CATEGORY = "KJNodes"
def generate(self, dimensions, invert, batch_size):
result = [x.strip() for x in dimensions.split('x')]
if invert:
width = int(result[1].split(' ')[0])
height = int(result[0])
else:
width = int(result[0])
height = int(result[1].split(' ')[0])
latent = EmptyLatentImage().generate(width, height, batch_size)[0]
return (latent, int(width), int(height),)
#https://github.com/hahnec/color-matcher/
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
class ColorMatch:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image_ref": ("IMAGE",),
"image_target": ("IMAGE",),
"method": (
[
'mkl',
'hm',
'reinhard',
'mvgd',
'hm-mvgd-hm',
'hm-mkl-hm',
], {
"default": 'mkl'
}),
"use_only_first": ("BOOLEAN", {"default": False}),
},
}
CATEGORY = "KJNodes"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "colormatch"
def colormatch(self, image_ref, image_target, use_only_first, method):
cm = ColorMatcher()
batch_size = image_target.shape[0]
out = []
images_target = image_target.squeeze()
images_ref = image_ref.squeeze()
print(image_ref.shape)
for i in range(batch_size):
if use_only_first:
image_ref = images_ref.numpy()
else:
print("MULTIPLE IMAGES")
image_ref = images_ref[i].numpy()
image_target = images_target[i].numpy()
try:
image_result = cm.transfer(src=image_target, ref=image_ref, method=method)
except BaseException as e:
print(f"Error occurred during transfer: {e}")
break
out.append(torch.from_numpy(image_result))
return (torch.stack(out, dim=0).to(torch.float32), )
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"ConditioningMultiCombine": ConditioningMultiCombine,
@ -866,6 +972,8 @@ NODE_CLASS_MAPPINGS = {
"VRAM_Debug" : VRAM_Debug,
"SomethingToString" : SomethingToString,
"CrossFadeImages": CrossFadeImages,
"EmptyLatentImagePresets": EmptyLatentImagePresets,
"ColorMatch": ColorMatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -883,4 +991,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VRAM_Debug" : "VRAM Debug",
"CrossFadeImages": "CrossFadeImages",
"SomethingToString": "SomethingToString",
"EmptyLatentImagePresets": "EmptyLatentImagePresets",
"ColorMatch": "ColorMatch",
}