Add Split/Merge image channels -nodes

This commit is contained in:
kijai 2024-05-01 10:51:49 +03:00
parent 2405a6192a
commit edae7ef9d2
2 changed files with 67 additions and 0 deletions

View File

@ -56,6 +56,8 @@ NODE_CLASS_MAPPINGS = {
"ImagePass": ImagePass,
"ImagePadForOutpaintMasked": ImagePadForOutpaintMasked,
"ImageAndMaskPreview": ImageAndMaskPreview,
"SplitImageChannels": SplitImageChannels,
"MergeImageChannels": MergeImageChannels,
#batch cropping
"BatchCropFromMask": BatchCropFromMask,
"BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
@ -194,6 +196,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"FloatToSigmas": "Float To Sigmas",
"CustomSigmas": "Custom Sigmas",
"ImagePass": "ImagePass",
"SplitImageChannels": "Split Image Channels",
"MergeImageChannels": "Merge Image Channels",
#curve nodes
"SplineEditor": "Spline Editor",
"CreateShapeMaskOnPath": "Create Shape Mask On Path",

View File

@ -3558,7 +3558,70 @@ Remaps the image values to the specified range.
if clamp:
image = torch.clamp(image, min=0.0, max=1.0)
return (image, )
class SplitImageChannels:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK")
RETURN_NAMES = ("red", "green", "blue", "mask")
FUNCTION = "split"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Splits image channels into images where the selected channel
is repeated for all channels, and the alpha as a mask.
"""
def split(self, image):
red = image[:, :, :, 0:1] # Red channel
green = image[:, :, :, 1:2] # Green channel
blue = image[:, :, :, 2:3] # Blue channel
alpha = image[:, :, :, 3:4] # Alpha channel
alpha = alpha.squeeze(-1)
# Repeat the selected channel for all channels
red = torch.cat([red, red, red], dim=3)
green = torch.cat([green, green, green], dim=3)
blue = torch.cat([blue, blue, blue], dim=3)
return (red, green, blue, alpha)
class MergeImageChannels:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"red": ("IMAGE",),
"green": ("IMAGE",),
"blue": ("IMAGE",),
},
"optional": {
"mask": ("MASK", {"default": None}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "merge"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Merges channel data into an image.
"""
def merge(self, red, green, blue, alpha=None):
image = torch.stack([
red[..., 0, None], # Red channel
green[..., 1, None], # Green channel
blue[..., 2, None] # Blue channel
], dim=-1)
image = image.squeeze(-2)
if alpha is not None:
image = torch.cat([image, alpha], dim=-1)
return (image,)
class CameraPoseVisualizer:
@classmethod