diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 0caf324..937ffce 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -28,6 +28,7 @@ try: from server import PromptServer except: PromptServer = None +from concurrent.futures import ThreadPoolExecutor script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -72,6 +73,7 @@ class ColorMatch: }, "optional": { "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "multithread": ("BOOLEAN", {"default": True}), } } @@ -93,37 +95,41 @@ https://github.com/hahnec/color-matcher/ """ - def colormatch(self, image_ref, image_target, method, strength=1.0): + def colormatch(self, image_ref, image_target, method, strength=1.0, multithread=True): try: from color_matcher import ColorMatcher except: raise Exception("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher") - cm = ColorMatcher() + image_ref = image_ref.cpu() image_target = image_target.cpu() batch_size = image_target.size(0) - out = [] + images_target = image_target.squeeze() images_ref = image_ref.squeeze() image_ref_np = images_ref.numpy() images_target_np = images_target.numpy() - if image_ref.size(0) > 1 and image_ref.size(0) != batch_size: - raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.") - - for i in range(batch_size): - image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy() + def process(i): + cm = ColorMatcher() + image_target_np_i = images_target_np if batch_size == 1 else images_target[i].numpy() image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy() try: - image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method) - except BaseException as e: - print(f"Error occurred during transfer: {e}") - break - # Apply the strength multiplier - image_result = image_target_np + strength * (image_result - image_target_np) - out.append(torch.from_numpy(image_result)) - + image_result = cm.transfer(src=image_target_np_i, ref=image_ref_np_i, method=method) + image_result = image_target_np_i + strength * (image_result - image_target_np_i) + return torch.from_numpy(image_result) + except Exception as e: + print(f"Thread {i} error: {e}") + return torch.from_numpy(image_target_np_i) # fallback + + if multithread and batch_size > 1: + max_threads = min(os.cpu_count() or 1, batch_size) + with ThreadPoolExecutor(max_workers=max_threads) as executor: + out = list(executor.map(process, range(batch_size))) + else: + out = [process(i) for i in range(batch_size)] + out = torch.stack(out, dim=0).to(torch.float32) out.clamp_(0, 1) return (out,)