mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 05:15:05 +08:00
Added multithread option to ColorMatch node
This commit is contained in:
parent
ad37ce656c
commit
0d2334de6d
@ -28,6 +28,7 @@ try:
|
|||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
except:
|
except:
|
||||||
PromptServer = None
|
PromptServer = None
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
@ -72,6 +73,7 @@ class ColorMatch:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"multithread": ("BOOLEAN", {"default": True}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,37 +95,41 @@ https://github.com/hahnec/color-matcher/
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def colormatch(self, image_ref, image_target, method, strength=1.0):
|
def colormatch(self, image_ref, image_target, method, strength=1.0, multithread=True):
|
||||||
try:
|
try:
|
||||||
from color_matcher import ColorMatcher
|
from color_matcher import ColorMatcher
|
||||||
except:
|
except:
|
||||||
raise Exception("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher")
|
raise Exception("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher")
|
||||||
cm = ColorMatcher()
|
|
||||||
image_ref = image_ref.cpu()
|
image_ref = image_ref.cpu()
|
||||||
image_target = image_target.cpu()
|
image_target = image_target.cpu()
|
||||||
batch_size = image_target.size(0)
|
batch_size = image_target.size(0)
|
||||||
out = []
|
|
||||||
images_target = image_target.squeeze()
|
images_target = image_target.squeeze()
|
||||||
images_ref = image_ref.squeeze()
|
images_ref = image_ref.squeeze()
|
||||||
|
|
||||||
image_ref_np = images_ref.numpy()
|
image_ref_np = images_ref.numpy()
|
||||||
images_target_np = images_target.numpy()
|
images_target_np = images_target.numpy()
|
||||||
|
|
||||||
if image_ref.size(0) > 1 and image_ref.size(0) != batch_size:
|
def process(i):
|
||||||
raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.")
|
cm = ColorMatcher()
|
||||||
|
image_target_np_i = images_target_np if batch_size == 1 else images_target[i].numpy()
|
||||||
for i in range(batch_size):
|
|
||||||
image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy()
|
|
||||||
image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy()
|
image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy()
|
||||||
try:
|
try:
|
||||||
image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method)
|
image_result = cm.transfer(src=image_target_np_i, ref=image_ref_np_i, method=method)
|
||||||
except BaseException as e:
|
image_result = image_target_np_i + strength * (image_result - image_target_np_i)
|
||||||
print(f"Error occurred during transfer: {e}")
|
return torch.from_numpy(image_result)
|
||||||
break
|
except Exception as e:
|
||||||
# Apply the strength multiplier
|
print(f"Thread {i} error: {e}")
|
||||||
image_result = image_target_np + strength * (image_result - image_target_np)
|
return torch.from_numpy(image_target_np_i) # fallback
|
||||||
out.append(torch.from_numpy(image_result))
|
|
||||||
|
if multithread and batch_size > 1:
|
||||||
|
max_threads = min(os.cpu_count() or 1, batch_size)
|
||||||
|
with ThreadPoolExecutor(max_workers=max_threads) as executor:
|
||||||
|
out = list(executor.map(process, range(batch_size)))
|
||||||
|
else:
|
||||||
|
out = [process(i) for i in range(batch_size)]
|
||||||
|
|
||||||
out = torch.stack(out, dim=0).to(torch.float32)
|
out = torch.stack(out, dim=0).to(torch.float32)
|
||||||
out.clamp_(0, 1)
|
out.clamp_(0, 1)
|
||||||
return (out,)
|
return (out,)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user