diff --git a/__init__.py b/__init__.py index dfbf668..0c435a1 100644 --- a/__init__.py +++ b/__init__.py @@ -56,6 +56,8 @@ NODE_CLASS_MAPPINGS = { "ImagePass": ImagePass, "ImagePadForOutpaintMasked": ImagePadForOutpaintMasked, "ImageAndMaskPreview": ImageAndMaskPreview, + "SplitImageChannels": SplitImageChannels, + "MergeImageChannels": MergeImageChannels, #batch cropping "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, @@ -194,6 +196,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "FloatToSigmas": "Float To Sigmas", "CustomSigmas": "Custom Sigmas", "ImagePass": "ImagePass", + "SplitImageChannels": "Split Image Channels", + "MergeImageChannels": "Merge Image Channels", #curve nodes "SplineEditor": "Spline Editor", "CreateShapeMaskOnPath": "Create Shape Mask On Path", diff --git a/nodes/nodes.py b/nodes/nodes.py index 1d1e7e7..9c2bab9 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -3558,7 +3558,70 @@ Remaps the image values to the specified range. if clamp: image = torch.clamp(image, min=0.0, max=1.0) return (image, ) + +class SplitImageChannels: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + }, + } + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK") + RETURN_NAMES = ("red", "green", "blue", "mask") + FUNCTION = "split" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Splits image channels into images where the selected channel +is repeated for all channels, and the alpha as a mask. +""" + + def split(self, image): + red = image[:, :, :, 0:1] # Red channel + green = image[:, :, :, 1:2] # Green channel + blue = image[:, :, :, 2:3] # Blue channel + alpha = image[:, :, :, 3:4] # Alpha channel + alpha = alpha.squeeze(-1) + + # Repeat the selected channel for all channels + red = torch.cat([red, red, red], dim=3) + green = torch.cat([green, green, green], dim=3) + blue = torch.cat([blue, blue, blue], dim=3) + return (red, green, blue, alpha) + +class MergeImageChannels: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "red": ("IMAGE",), + "green": ("IMAGE",), + "blue": ("IMAGE",), + + }, + "optional": { + "mask": ("MASK", {"default": None}), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image",) + FUNCTION = "merge" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Merges channel data into an image. +""" + + def merge(self, red, green, blue, alpha=None): + image = torch.stack([ + red[..., 0, None], # Red channel + green[..., 1, None], # Green channel + blue[..., 2, None] # Blue channel + ], dim=-1) + image = image.squeeze(-2) + if alpha is not None: + image = torch.cat([image, alpha], dim=-1) + return (image,) + class CameraPoseVisualizer: @classmethod