diff --git a/nodes.py b/nodes.py index ad2bfa5..32612f5 100644 --- a/nodes.py +++ b/nodes.py @@ -4,6 +4,9 @@ import torch.nn.functional as F import scipy.ndimage import numpy as np from PIL import ImageColor, Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo +import json +import re import os import librosa from scipy.special import erf @@ -11,6 +14,8 @@ from .fluid import Fluid import comfy.model_management import math from nodes import MAX_RESOLUTION +import folder_paths +from comfy.cli_args import args script_dir = os.path.dirname(os.path.abspath(__file__)) @@ -998,7 +1003,80 @@ class ColorMatch: break out.append(torch.from_numpy(image_result)) return (torch.stack(out, dim=0).to(torch.float32), ) + +class SaveImageWithAlpha: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "mask": ("MASK", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images_alpha" + + OUTPUT_NODE = True + + CATEGORY = "image" + + def save_images_alpha(self, images, mask, filename_prefix="ComfyUI_image_with_alpha", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + + def file_counter(): + max_counter = 0 + # Loop through the existing files + for existing_file in os.listdir(full_output_folder): + # Check if the file matches the expected format + match = re.fullmatch(f"{filename}_(\d+)_?\.[a-zA-Z0-9]+", existing_file) + if match: + # Extract the numeric portion of the filename + file_counter = int(match.group(1)) + # Update the maximum counter value if necessary + if file_counter > max_counter: + max_counter = file_counter + return max_counter + + for image, alpha in zip(images, mask): + i = 255. * image.cpu().numpy() + a = 255. * alpha.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + + if a.shape == img.size[::-1]: # Check if the mask has the same size as the image + print("Applying mask") + a = np.clip(a, 0, 255).astype(np.uint8) + img.putalpha(Image.fromarray(a, mode='L')) + else: + raise ValueError("SaveImageWithAlpha: Mask size does not match") + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + # Increment the counter by 1 to get the next available value + counter = file_counter() + 1 + file = f"{filename}_{counter:05}.png" + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + + return { "ui": { "images": results } } + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -1019,7 +1097,8 @@ NODE_CLASS_MAPPINGS = { "CrossFadeImages": CrossFadeImages, "EmptyLatentImagePresets": EmptyLatentImagePresets, "ColorMatch": ColorMatch, - "GetImageRangeFromBatch": GetImageRangeFromBatch + "GetImageRangeFromBatch": GetImageRangeFromBatch, + "SaveImageWithAlpha": SaveImageWithAlpha } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -1040,5 +1119,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SomethingToString": "SomethingToString", "EmptyLatentImagePresets": "EmptyLatentImagePresets", "ColorMatch": "ColorMatch", - "GetImageRangeFromBatch": "GetImageRangeFromBatch" + "GetImageRangeFromBatch": "GetImageRangeFromBatch", + "SaveImageWithAlpha": "SaveImageWithAlpha" } \ No newline at end of file