mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-22 03:04:29 +08:00
Handle ColorMatch inputs better
This commit is contained in:
parent
213acc2adb
commit
defe5940d7
25
nodes.py
25
nodes.py
@ -920,7 +920,7 @@ class ColorMatch:
|
|||||||
], {
|
], {
|
||||||
"default": 'mkl'
|
"default": 'mkl'
|
||||||
}),
|
}),
|
||||||
"use_only_first": ("BOOLEAN", {"default": False}),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -930,29 +930,28 @@ class ColorMatch:
|
|||||||
RETURN_NAMES = ("image",)
|
RETURN_NAMES = ("image",)
|
||||||
FUNCTION = "colormatch"
|
FUNCTION = "colormatch"
|
||||||
|
|
||||||
def colormatch(self, image_ref, image_target, use_only_first, method):
|
def colormatch(self, image_ref, image_target, method):
|
||||||
cm = ColorMatcher()
|
cm = ColorMatcher()
|
||||||
batch_size = image_target.shape[0]
|
batch_size = image_target.size(0)
|
||||||
out = []
|
out = []
|
||||||
images_target = image_target.squeeze()
|
images_target = image_target.squeeze()
|
||||||
images_ref = image_ref.squeeze()
|
images_ref = image_ref.squeeze()
|
||||||
print(image_ref.shape)
|
|
||||||
|
|
||||||
|
image_ref_np = images_ref.numpy()
|
||||||
|
images_target_np = images_target.numpy()
|
||||||
|
|
||||||
|
if image_ref.size(0) > 1 and image_ref.size(0) != batch_size:
|
||||||
|
raise ValueError("ColorMatch: Use either single reference image or a matching batch of reference images.")
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
if use_only_first:
|
image_target_np = images_target_np if batch_size == 1 else images_target[i].numpy()
|
||||||
image_ref = images_ref.numpy()
|
image_ref_np_i = image_ref_np if image_ref.size(0) == 1 else images_ref[i].numpy()
|
||||||
else:
|
|
||||||
print("MULTIPLE IMAGES")
|
|
||||||
image_ref = images_ref[i].numpy()
|
|
||||||
image_target = images_target[i].numpy()
|
|
||||||
try:
|
try:
|
||||||
image_result = cm.transfer(src=image_target, ref=image_ref, method=method)
|
image_result = cm.transfer(src=image_target_np, ref=image_ref_np_i, method=method)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(f"Error occurred during transfer: {e}")
|
print(f"Error occurred during transfer: {e}")
|
||||||
break
|
break
|
||||||
out.append(torch.from_numpy(image_result))
|
out.append(torch.from_numpy(image_result))
|
||||||
|
|
||||||
return (torch.stack(out, dim=0).to(torch.float32), )
|
return (torch.stack(out, dim=0).to(torch.float32), )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user