Add PadImageBatchInterleaved

This commit is contained in:
kijai 2025-04-02 02:07:47 +03:00
parent 0addfc6101
commit 59bd92ff49
2 changed files with 54 additions and 1 deletions

View File

@ -82,6 +82,7 @@ NODE_CONFIG = {
"LoadAndResizeImage": {"class": LoadAndResizeImage, "name": "Load & Resize Image"},
"LoadImagesFromFolderKJ": {"class": LoadImagesFromFolderKJ, "name": "Load Images From Folder (KJ)"},
"MergeImageChannels": {"class": MergeImageChannels, "name": "Merge Image Channels"},
"PadImageBatchInterleaved": {"class": PadImageBatchInterleaved, "name": "Pad Image Batch Interleaved"},
"PreviewAnimation": {"class": PreviewAnimation, "name": "Preview Animation"},
"RemapImageRange": {"class": RemapImageRange, "name": "Remap Image Range"},
"ReverseImageBatch": {"class": ReverseImageBatch, "name": "Reverse Image Batch"},

View File

@ -1122,6 +1122,9 @@ class ImagePrepForICLora:
print("Warning: The incoming mask is fully black. Handling it as None.")
reference_mask = None
image = reference_image
if latent_image is not None:
if image.shape[0] != latent_image.shape[0]:
image = image.repeat(latent_image.shape[0], 1, 1, 1)
B, H, W, C = image.size()
# Handle mask
@ -1849,6 +1852,52 @@ Inserts images at the specified indices into the original image batch.
return (original_images,)
class PadImageBatchInterleaved:
RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("images", "masks",)
FUNCTION = "pad"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Inserts empty frames between the images in a batch.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE",),
"empty_frames_per_image": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
},
}
def pad(self, images, empty_frames_per_image):
B, H, W, C = images.shape
if B == 1 or empty_frames_per_image == 0:
return (images,)
# Original B images + (B-1) sets of empty frames between them
total_frames = B + (B-1) * empty_frames_per_image
# Create new tensor with zeros (empty frames)
padded_batch = torch.zeros((total_frames, H, W, C),
dtype=images.dtype,
device=images.device)
# Create mask tensor (1 for original frames, 0 for empty frames)
mask = torch.zeros((total_frames, H, W),
dtype=images.dtype,
device=images.device)
# Fill in original images at their new positions
for i in range(B):
# Each image is separated by empty_frames_per_image blank frames
new_pos = i * (empty_frames_per_image + 1)
padded_batch[new_pos] = images[i]
mask[new_pos] = 1.0 # Mark this as an original frame
return (padded_batch, mask)
class ReplaceImagesInBatch:
RETURN_TYPES = ("IMAGE",)
@ -3147,11 +3196,14 @@ class ImagePadKJ:
out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0) # Expand for H and W dimensions
out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
if mask is not None:
out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device)
for m in range(BM):
out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = mask[m]
else:
out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device)
out_masks = torch.ones((B, padded_height, padded_width), dtype=image.dtype, device=image.device)
for m in range(B):
out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = 0.0
return (out_image, out_masks)