diff --git a/nodes.py b/nodes.py index ab81e3e..82fa72d 100644 --- a/nodes.py +++ b/nodes.py @@ -920,7 +920,7 @@ class ColorMatch: ], { "default": 'mkl' }), - "use_only_first": ("BOOLEAN", {"default": False}), + }, } @@ -930,29 +930,28 @@ class ColorMatch: RETURN_NAMES = ("image",) FUNCTION = "colormatch" - def colormatch(self, image_ref, image_target, use_only_first, method): + def colormatch(self, image_ref, image_target, method): cm = ColorMatcher() - batch_size = image_target.shape[0] + batch_size = image_target.size(0) out = [] images_target = image_target.squeeze() images_ref = image_ref.squeeze() - print(image_ref.shape) - + image_ref_np = images_ref.numpy() + images_target_np = images_target.numpy() + + if image_ref.size(0) > 1 and image_ref.size(0) != batch_size: + raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.") + for i in range(batch_size): - if use_only_first: - image_ref = images_ref.numpy() - else: - print("MULTIPLE IMAGES") - image_ref = images_ref[i].numpy() - image_target = images_target[i].numpy() + image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy() + image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy() try: - image_result = cm.transfer(src=image_target, ref=image_ref, method=method) + image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method) except BaseException as e: print(f"Error occurred during transfer: {e}") break out.append(torch.from_numpy(image_result)) - return (torch.stack(out, dim=0).to(torch.float32), ) NODE_CLASS_MAPPINGS = {