diff --git a/__init__.py b/__init__.py index fce7a50..677d6a3 100644 --- a/__init__.py +++ b/__init__.py @@ -55,6 +55,7 @@ NODE_CONFIG = { "ImageBatchFilter": {"class": ImageBatchFilter, "name": "Image Batch Filter"}, "ImageAndMaskPreview": {"class": ImageAndMaskPreview}, "ImageAddMulti": {"class": ImageAddMulti, "name": "Image Add Multi"}, + "ImageBatchJoinWithTransition": {"class": ImageBatchJoinWithTransition, "name": "Image Batch Join With Transition"}, "ImageBatchMulti": {"class": ImageBatchMulti, "name": "Image Batch Multi"}, "ImageBatchRepeatInterleaving": {"class": ImageBatchRepeatInterleaving}, "ImageBatchTestPattern": {"class": ImageBatchTestPattern, "name": "Image Batch Test Pattern"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 70a8067..14fe31b 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1251,7 +1251,39 @@ nodes for example. if pass_through: return (preview, ) return(self.save_images(preview, filename_prefix, prompt, extra_pnginfo)) - + +def crossfade(images_1, images_2, alpha): + crossfade = (1 - alpha) * images_1 + alpha * images_2 + return crossfade +def ease_in(t): + return t * t +def ease_out(t): + return 1 - (1 - t) * (1 - t) +def ease_in_out(t): + return 3 * t * t - 2 * t * t * t +def bounce(t): + if t < 0.5: + return ease_out(t * 2) * 0.5 + else: + return ease_in((t - 0.5) * 2) * 0.5 + 0.5 +def elastic(t): + return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1)) +def glitchy(t): + return t + 0.1 * math.sin(40 * t) +def exponential_ease_out(t): + return 1 - (1 - t) ** 4 + +easing_functions = { + "linear": lambda t: t, + "ease_in": ease_in, + "ease_out": ease_out, + "ease_in_out": ease_in_out, + "bounce": bounce, + "elastic": elastic, + "glitchy": glitchy, + "exponential_ease_out": exponential_ease_out, +} + class CrossFadeImages: RETURN_TYPES = ("IMAGE",) @@ -1274,38 +1306,6 @@ class CrossFadeImages: def crossfadeimages(self, images_1, images_2, transition_start_index, transitioning_frames, interpolation, start_level, end_level): - def crossfade(images_1, images_2, alpha): - crossfade = (1 - alpha) * images_1 + alpha * images_2 - return crossfade - def ease_in(t): - return t * t - def ease_out(t): - return 1 - (1 - t) * (1 - t) - def ease_in_out(t): - return 3 * t * t - 2 * t * t * t - def bounce(t): - if t < 0.5: - return ease_out(t * 2) * 0.5 - else: - return ease_in((t - 0.5) * 2) * 0.5 + 0.5 - def elastic(t): - return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1)) - def glitchy(t): - return t + 0.1 * math.sin(40 * t) - def exponential_ease_out(t): - return 1 - (1 - t) ** 4 - - easing_functions = { - "linear": lambda t: t, - "ease_in": ease_in, - "ease_out": ease_out, - "ease_in_out": ease_in_out, - "bounce": bounce, - "elastic": elastic, - "glitchy": glitchy, - "exponential_ease_out": exponential_ease_out, - } - crossfade_images = [] if transition_start_index < 0: @@ -1359,38 +1359,6 @@ class CrossFadeImagesMulti: def crossfadeimages(self, inputcount, transitioning_frames, interpolation, **kwargs): - def crossfade(images_1, images_2, alpha): - crossfade = (1 - alpha) * images_1 + alpha * images_2 - return crossfade - def ease_in(t): - return t * t - def ease_out(t): - return 1 - (1 - t) * (1 - t) - def ease_in_out(t): - return 3 * t * t - 2 * t * t * t - def bounce(t): - if t < 0.5: - return self.ease_out(t * 2) * 0.5 - else: - return self.ease_in((t - 0.5) * 2) * 0.5 + 0.5 - def elastic(t): - return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1)) - def glitchy(t): - return t + 0.1 * math.sin(40 * t) - def exponential_ease_out(t): - return 1 - (1 - t) ** 4 - - easing_functions = { - "linear": lambda t: t, - "ease_in": ease_in, - "ease_out": ease_out, - "ease_in_out": ease_in_out, - "bounce": bounce, - "elastic": elastic, - "glitchy": glitchy, - "exponential_ease_out": exponential_ease_out, - } - image_1 = kwargs["image_1"] height = image_1.shape[1] width = image_1.shape[2] @@ -1423,73 +1391,55 @@ class CrossFadeImagesMulti: return image_1, def transition_images(images_1, images_2, alpha, transition_type, blur_radius, reverse): - width = images_1.shape[1] - height = images_1.shape[0] + width = images_1.shape[1] + height = images_1.shape[0] - mask = torch.zeros_like(images_1, device=images_1.device) - - alpha = alpha.item() - if reverse: - alpha = 1 - alpha + mask = torch.zeros_like(images_1, device=images_1.device) + + alpha = alpha.item() + if reverse: + alpha = 1 - alpha - #transitions from matteo's essential nodes - if "horizontal slide" in transition_type: - pos = round(width * alpha) - mask[:, :pos, :] = 1.0 - elif "vertical slide" in transition_type: - pos = round(height * alpha) - mask[:pos, :, :] = 1.0 - elif "box" in transition_type: - box_w = round(width * alpha) - box_h = round(height * alpha) - x1 = (width - box_w) // 2 - y1 = (height - box_h) // 2 - x2 = x1 + box_w - y2 = y1 + box_h - mask[y1:y2, x1:x2, :] = 1.0 - elif "circle" in transition_type: - radius = math.ceil(math.sqrt(pow(width, 2) + pow(height, 2)) * alpha / 2) - c_x = width // 2 - c_y = height // 2 - x = torch.arange(0, width, dtype=torch.float32, device="cpu") - y = torch.arange(0, height, dtype=torch.float32, device="cpu") - y, x = torch.meshgrid((y, x), indexing="ij") - circle = ((x - c_x) ** 2 + (y - c_y) ** 2) <= (radius ** 2) - mask[circle] = 1.0 - elif "horizontal door" in transition_type: - bar = math.ceil(height * alpha / 2) - if bar > 0: - mask[:bar, :, :] = 1.0 - mask[-bar:,:, :] = 1.0 - elif "vertical door" in transition_type: - bar = math.ceil(width * alpha / 2) - if bar > 0: - mask[:, :bar,:] = 1.0 - mask[:, -bar:,:] = 1.0 - elif "fade" in transition_type: - mask[:, :, :] = alpha + #transitions from matteo's essential nodes + if "horizontal slide" in transition_type: + pos = round(width * alpha) + mask[:, :pos, :] = 1.0 + elif "vertical slide" in transition_type: + pos = round(height * alpha) + mask[:pos, :, :] = 1.0 + elif "box" in transition_type: + box_w = round(width * alpha) + box_h = round(height * alpha) + x1 = (width - box_w) // 2 + y1 = (height - box_h) // 2 + x2 = x1 + box_w + y2 = y1 + box_h + mask[y1:y2, x1:x2, :] = 1.0 + elif "circle" in transition_type: + radius = math.ceil(math.sqrt(pow(width, 2) + pow(height, 2)) * alpha / 2) + c_x = width // 2 + c_y = height // 2 + x = torch.arange(0, width, dtype=torch.float32, device="cpu") + y = torch.arange(0, height, dtype=torch.float32, device="cpu") + y, x = torch.meshgrid((y, x), indexing="ij") + circle = ((x - c_x) ** 2 + (y - c_y) ** 2) <= (radius ** 2) + mask[circle] = 1.0 + elif "horizontal door" in transition_type: + bar = math.ceil(height * alpha / 2) + if bar > 0: + mask[:bar, :, :] = 1.0 + mask[-bar:,:, :] = 1.0 + elif "vertical door" in transition_type: + bar = math.ceil(width * alpha / 2) + if bar > 0: + mask[:, :bar,:] = 1.0 + mask[:, -bar:,:] = 1.0 + elif "fade" in transition_type: + mask[:, :, :] = alpha - mask = gaussian_blur(mask, blur_radius) + mask = gaussian_blur(mask, blur_radius) - return images_1 * (1 - mask) + images_2 * mask - -def ease_in(t): - return t * t -def ease_out(t): - return 1 - (1 - t) * (1 - t) -def ease_in_out(t): - return 3 * t * t - 2 * t * t * t -def bounce(t): - if t < 0.5: - return ease_out(t * 2) * 0.5 - else: - return ease_in((t - 0.5) * 2) * 0.5 + 0.5 -def elastic(t): - return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1)) -def glitchy(t): - return t + 0.1 * math.sin(40 * t) -def exponential_ease_out(t): - return 1 - (1 - t) ** 4 + return images_1 * (1 - mask) + images_2 * mask def gaussian_blur(mask, blur_radius): if blur_radius > 0: @@ -1643,6 +1593,74 @@ Creates transitions between images in a batch. return images.cpu(), +class ImageBatchJoinWithTransition: + RETURN_TYPES = ("IMAGE",) + FUNCTION = "transition_batches" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Transitions between two batches of images, starting at a specified index in the first batch. +During the transition, frames from both batches are blended frame-by-frame, so the video keeps playing. +""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images_1": ("IMAGE",), + "images_2": ("IMAGE",), + "start_index": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1}), + "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],), + "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"],), + "transitioning_frames": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}), + "blur_radius": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "reverse": ("BOOLEAN", {"default": False}), + "device": (["CPU", "GPU"], {"default": "CPU"}), + }, + } + + def transition_batches(self, images_1, images_2, start_index, interpolation, transition_type, transitioning_frames, blur_radius, reverse, device): + if images_1.shape[0] == 0 or images_2.shape[0] == 0: + raise ValueError("Both input batches must have at least one image.") + + if start_index < 0: + start_index = images_1.shape[0] + start_index + if start_index < 0 or start_index > images_1.shape[0]: + raise ValueError("start_index is out of range.") + + gpu = model_management.get_torch_device() + easing_function = easing_functions[interpolation] + out_frames = [] + + # Add images from images_1 up to start_index + if start_index > 0: + out_frames.append(images_1[:start_index]) + + # Determine how many frames we can blend + max_transition = min(transitioning_frames, images_1.shape[0] - start_index, images_2.shape[0]) + + # Blend corresponding frames from both batches + for i in range(max_transition): + img1 = images_1[start_index + i] + img2 = images_2[i] + if device == "GPU": + img1 = img1.to(gpu) + img2 = img2.to(gpu) + if reverse: + img1, img2 = img2, img1 + t = i / (max_transition - 1) if max_transition > 1 else 1.0 + alpha = easing_function(t) + alpha_tensor = torch.tensor(alpha, dtype=img1.dtype, device=img1.device) + frame_image = transition_images(img1, img2, alpha_tensor, transition_type, blur_radius, reverse) + out_frames.append(frame_image.cpu().unsqueeze(0)) + + # Add remaining images from images_2 after transition + if images_2.shape[0] > max_transition: + out_frames.append(images_2[max_transition:]) + + # Concatenate all frames + out = torch.cat(out_frames, dim=0) + return (out.cpu(),) + class ShuffleImageBatch: RETURN_TYPES = ("IMAGE",) FUNCTION = "shuffle"