Add LoadAndResizeImage

This commit is contained in:
Kijai 2024-05-14 15:20:03 +03:00
parent f741ef0252
commit 7e6bd8d14a
2 changed files with 128 additions and 2 deletions

View File

@ -61,6 +61,7 @@ NODE_CONFIG = {
"ImageResizeKJ": {"class": ImageResizeKJ, "name": "Resize Image"}, "ImageResizeKJ": {"class": ImageResizeKJ, "name": "Resize Image"},
"ImageUpscaleWithModelBatched": {"class": ImageUpscaleWithModelBatched, "name": "Image Upscale With Model Batched"}, "ImageUpscaleWithModelBatched": {"class": ImageUpscaleWithModelBatched, "name": "Image Upscale With Model Batched"},
"InsertImagesToBatchIndexed": {"class": InsertImagesToBatchIndexed, "name": "Insert Images To Batch Indexed"}, "InsertImagesToBatchIndexed": {"class": InsertImagesToBatchIndexed, "name": "Insert Images To Batch Indexed"},
"LoadAndResizeImage": {"class": LoadAndResizeImage, "name": "Load & Resize Image"},
"MergeImageChannels": {"class": MergeImageChannels, "name": "Merge Image Channels"}, "MergeImageChannels": {"class": MergeImageChannels, "name": "Merge Image Channels"},
"PreviewAnimation": {"class": PreviewAnimation, "name": "Preview Animation"}, "PreviewAnimation": {"class": PreviewAnimation, "name": "Preview Animation"},
"RemapImageRange": {"class": RemapImageRange, "name": "Remap Image Range"}, "RemapImageRange": {"class": RemapImageRange, "name": "Remap Image Range"},

View File

@ -7,7 +7,8 @@ import math
import os import os
import re import re
import json import json
from PIL import ImageGrab, ImageDraw, ImageFont, Image import hashlib
from PIL import ImageGrab, ImageDraw, ImageFont, Image, ImageSequence, ImageOps
from nodes import MAX_RESOLUTION, SaveImage from nodes import MAX_RESOLUTION, SaveImage
from comfy_extras.nodes_mask import ImageCompositeMasked from comfy_extras.nodes_mask import ImageCompositeMasked
@ -15,6 +16,7 @@ from comfy.cli_args import args
from comfy.utils import ProgressBar, common_upscale from comfy.utils import ProgressBar, common_upscale
import folder_paths import folder_paths
import model_management import model_management
import node_helpers
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -1290,6 +1292,14 @@ class ImageResizeKJ:
CATEGORY = "KJNodes/image" CATEGORY = "KJNodes/image"
DESCRIPTION = """ DESCRIPTION = """
Resizes the image to the specified width and height. Resizes the image to the specified width and height.
Size can be retrieved from the inputs, and the final scale
is determined in this order of importance:
- get_image_size
- width_input and height_input
- width and height widgets
Keep proportions keeps the aspect ratio of the image, by
highest dimension.
""" """
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, width_input=None, height_input=None, get_image_size=None): def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, width_input=None, height_input=None, get_image_size=None):
@ -1319,4 +1329,119 @@ Resizes the image to the specified width and height.
scaled = common_upscale(image, width, height, upscale_method, 'disabled') scaled = common_upscale(image, width, height, upscale_method, 'disabled')
scaled = scaled.movedim(1,-1) scaled = scaled.movedim(1,-1)
return(scaled, scaled.shape[2], scaled.shape[1],) return(scaled, scaled.shape[2], scaled.shape[1],)
class LoadAndResizeImage:
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{
"image": (sorted(files), {"image_upload": True}),
"resize": ("BOOLEAN", { "default": False }),
"width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
"height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
"repeat": ("INT", { "default": 1, "min": 1, "max": 4096, "step": 1, }),
"keep_proportion": ("BOOLEAN", { "default": False }),
"divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
"mask_channel": (s._color_channels, ),
},
}
CATEGORY = "KJNodes/image"
RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT",)
RETURN_NAMES = ("image", "mask", "width", "height",)
FUNCTION = "load_image"
def load_image(self, image, resize, width, height, repeat, keep_proportion, divisible_by, mask_channel):
image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
W, H = img.size
if resize:
if keep_proportion:
ratio = min(width / W, height / H)
width = round(W * ratio)
height = round(H * ratio)
else:
if width == 0:
width = W
if height == 0:
height = H
if divisible_by > 1:
width = width - (width % divisible_by)
height = height - (height % divisible_by)
else:
width, height = W, H
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
if resize:
image = image.resize((width, height), Image.Resampling.BILINEAR)
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
mask = None
c = mask_channel[0].upper()
if c in i.getbands():
if resize:
i = i.resize((width, height), Image.Resampling.BILINEAR)
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)
if c == 'A':
mask = 1. - mask
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
if repeat > 1:
output_image = output_image.repeat(repeat, 1, 1, 1)
output_mask = output_mask.repeat(repeat, 1, 1)
return (output_image, output_mask, width, height)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True