Handle ColorMatch inputs better

This commit is contained in:
kijai 2023-11-01 00:11:15 +02:00
parent 213acc2adb
commit defe5940d7

View File

@ -920,7 +920,7 @@ class ColorMatch:
], {
"default": 'mkl'
}),
"use_only_first": ("BOOLEAN", {"default": False}),
},
}
@ -930,29 +930,28 @@ class ColorMatch:
RETURN_NAMES = ("image",)
FUNCTION = "colormatch"
def colormatch(self, image_ref, image_target, use_only_first, method):
def colormatch(self, image_ref, image_target, method):
cm = ColorMatcher()
batch_size = image_target.shape[0]
batch_size = image_target.size(0)
out = []
images_target = image_target.squeeze()
images_ref = image_ref.squeeze()
print(image_ref.shape)
image_ref_np = images_ref.numpy()
images_target_np = images_target.numpy()
if image_ref.size(0) > 1 and image_ref.size(0) != batch_size:
raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.")
for i in range(batch_size):
if use_only_first:
image_ref = images_ref.numpy()
else:
print("MULTIPLE IMAGES")
image_ref = images_ref[i].numpy()
image_target = images_target[i].numpy()
image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy()
image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy()
try:
image_result = cm.transfer(src=image_target, ref=image_ref, method=method)
image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method)
except BaseException as e:
print(f"Error occurred during transfer: {e}")
break
out.append(torch.from_numpy(image_result))
return (torch.stack(out, dim=0).to(torch.float32), )
NODE_CLASS_MAPPINGS = {