LoadImagesFromFolderKJ: resize batch to largest image in the folder

This commit is contained in:
kijai 2025-01-31 12:24:40 +02:00
parent 2abf557e3d
commit 75a990b40d
3 changed files with 40 additions and 28 deletions

View File

@ -2182,8 +2182,8 @@ class LoadImagesFromFolderKJ:
RETURN_TYPES = ("IMAGE", "MASK", "INT", "STRING",)
RETURN_NAMES = ("image", "mask", "count", "image_path",)
FUNCTION = "load_images"
CATEGORY = "image"
CATEGORY = "KJNodes/image"
DESCRIPTION = """Loads images from a folder into a batch, images are resized to match the largest image in the folder."""
def load_images(self, folder, image_load_cap, start_index):
if not os.path.isdir(folder):
@ -2202,6 +2202,17 @@ class LoadImagesFromFolderKJ:
# start at start_index
dir_files = dir_files[start_index:]
# First pass - find maximum dimensions
max_width = 0
max_height = 0
for image_path in dir_files:
if os.path.isdir(image_path):
continue
with Image.open(image_path) as img:
width, height = img.size
max_width = max(max_width, width)
max_height = max(max_height, height)
images = []
masks = []
image_path_list = []
@ -2211,24 +2222,33 @@ class LoadImagesFromFolderKJ:
limit_images = True
image_count = 0
has_non_empty_mask = False
for image_path in dir_files:
if os.path.isdir(image_path) and os.path.ex:
if os.path.isdir(image_path):
continue
if limit_images and image_count >= image_load_cap:
break
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
# Resize image to maximum dimensions
if i.size != (max_width, max_height):
i = i.resize((max_width, max_height), Image.Resampling.LANCZOS)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
has_non_empty_mask = True
if mask.shape != (max_height, max_width):
mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0),
size=(max_height, max_width),
mode='bilinear',
align_corners=False).squeeze()
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
mask = torch.zeros((max_height, max_width), dtype=torch.float32, device="cpu")
images.append(image)
masks.append(mask)
image_path_list.append(image_path)
@ -2236,30 +2256,16 @@ class LoadImagesFromFolderKJ:
if len(images) == 1:
return (images[0], masks[0], 1, image_path_list)
elif len(images) > 1:
image1 = images[0]
mask1 = None
mask1 = masks[0].unsqueeze(0)
for image2 in images[1:]:
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1)
image1 = torch.cat((image1, image2), dim=0)
for mask2 in masks[1:]:
if has_non_empty_mask:
if image1.shape[1:3] != mask2.shape:
mask2 = torch.nn.functional.interpolate(mask2.unsqueeze(0).unsqueeze(0), size=(image1.shape[2], image1.shape[1]), mode='bilinear', align_corners=False)
mask2 = mask2.squeeze(0)
else:
mask2 = mask2.unsqueeze(0)
else:
mask2 = mask2.unsqueeze(0)
if mask1 is None:
mask1 = mask2
else:
mask1 = torch.cat((mask1, mask2), dim=0)
mask1 = torch.cat((mask1, mask2.unsqueeze(0)), dim=0)
return (image1, mask1, len(images), image_path_list)

View File

@ -166,10 +166,11 @@ Selects and returns the latents at the specified indices as an latent batch.
"required": {
"latents": ("LATENT",),
"indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
"latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}),
},
}
def indexedlatentsfrombatch(self, latents, indexes):
def indexedlatentsfrombatch(self, latents, indexes, latent_format):
samples = latents.copy()
latent_samples = samples["samples"]
@ -181,7 +182,12 @@ Selects and returns the latents at the specified indices as an latent batch.
indices_tensor = torch.tensor(index_list, dtype=torch.long)
# Select the latents at the specified indices
chosen_latents = latent_samples[indices_tensor]
if latent_format == "BCHW":
chosen_latents = latent_samples[indices_tensor]
elif latent_format == "BTCHW":
chosen_latents = latent_samples[:, indices_tensor]
elif latent_format == "BCTHW":
chosen_latents = latent_samples[:, :, indices_tensor]
samples["samples"] = chosen_latents
return (samples,)
@ -2216,4 +2222,5 @@ Concatenates the audio1 to audio2 in the specified direction.
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) # Concatenate along width
elif direction == 'left':
concatenated_audio= torch.cat((waveform_2, waveform_1), dim=2) # Concatenate along width
return ({"waveform": concatenated_audio, "sample_rate": sample_rate_1},)
return ({"waveform": concatenated_audio, "sample_rate": sample_rate_1},)

View File

@ -95,7 +95,6 @@ app.registerExtension({
app.ui.settings.addSetting({
id: "KJNodes.nodeAutoColor",
name: "KJNodes: Automatically set node colors",
defaultValue: true,
type: "boolean",
defaultValue: true,
});