diff --git a/__init__.py b/__init__.py
index a9b94c9..f1ce49c 100644
--- a/__init__.py
+++ b/__init__.py
@@ -43,6 +43,7 @@ NODE_CONFIG = {
     "AddLabel": {"class": AddLabel, "name": "Add Label"},
     "ColorMatch": {"class": ColorMatch, "name": "Color Match"},
     "CrossFadeImages": {"class": CrossFadeImages, "name": "Cross Fade Images"},
+    "CrossFadeImagesMulti": {"class": CrossFadeImagesMulti, "name": "Cross Fade Images Multi"},
     "GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"},
     "GetImageRangeFromBatch": {"class": GetImageRangeFromBatch, "name": "Get Image or Mask Range From Batch"},
     "GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"},
diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py
index 803df74..369f82e 100644
--- a/nodes/image_nodes.py
+++ b/nodes/image_nodes.py
@@ -1224,6 +1224,88 @@ class CrossFadeImages:
         beginning_images_1 = images_1[:transition_start_index]
         crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0)
         return (crossfade_images, )
+
+class CrossFadeImagesMulti:
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "crossfadeimages"
+    CATEGORY = "KJNodes/image"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
+                "image_1": ("IMAGE",),
+                "image_2": ("IMAGE",),
+                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
+                "transitioning_frames": ("INT", {"default": 1, "min": 0, "max": 4096, "step": 1}),
+            },
+        }
+
+    def crossfadeimages(self, inputcount, transitioning_frames, interpolation, **kwargs):
+
+        def crossfade(images_1, images_2, alpha):
+            crossfade = (1 - alpha) * images_1 + alpha * images_2
+            return crossfade
+        def ease_in(t):
+            return t * t
+        def ease_out(t):
+            return 1 - (1 - t) * (1 - t)
+        def ease_in_out(t):
+            return 3 * t * t - 2 * t * t * t
+        def bounce(t):
+            if t < 0.5:
+                return ease_out(t * 2) * 0.5
+            else:
+                return ease_in((t - 0.5) * 2) * 0.5 + 0.5
+        def elastic(t):
+            return math.sin(13 * math.pi / 2 * t) * math.pow(2, 10 * (t - 1))
+        def glitchy(t):
+            return t + 0.1 * math.sin(40 * t)
+        def exponential_ease_out(t):
+            return 1 - (1 - t) ** 4
+
+        easing_functions = {
+            "linear": lambda t: t,
+            "ease_in": ease_in,
+            "ease_out": ease_out,
+            "ease_in_out": ease_in_out,
+            "bounce": bounce,
+            "elastic": elastic,
+            "glitchy": glitchy,
+            "exponential_ease_out": exponential_ease_out,
+        }
+
+        image_1 = kwargs["image_1"]
+        height = image_1.shape[1]
+        width = image_1.shape[2]
+
+        easing_function = easing_functions[interpolation]
+
+        for c in range(1, inputcount):
+            frames = []
+            new_image = kwargs[f"image_{c + 1}"]
+            new_image_height = new_image.shape[1]
+            new_image_width = new_image.shape[2]
+
+            if new_image_height != height or new_image_width != width:
+                new_image = common_upscale(new_image.movedim(-1, 1), width, height, "lanczos", "disabled")
+                new_image = new_image.movedim(1, -1)  # Move channels back to the last dimension
+
+            last_frame_image_1 = image_1[-1]
+            first_frame_image_2 = new_image[0]
+
+            for frame in range(transitioning_frames):
+                t = frame / max(transitioning_frames - 1, 1)  # avoid division by zero with a single transition frame
+                alpha = easing_function(t)
+                alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device)
+                frame_image = crossfade(last_frame_image_1, first_frame_image_2, alpha_tensor)
+                frames.append(frame_image)
+
+            frames = torch.stack(frames) if frames else image_1[:0]  # empty tensor when transitioning_frames is 0
+            image_1 = torch.cat((image_1, frames, new_image), dim=0)
+
+        return image_1,
 
 
 class GetImageRangeFromBatch:
diff --git a/web/js/jsnodes.js b/web/js/jsnodes.js
index aa71369..5f7936f 100644
--- a/web/js/jsnodes.js
+++ b/web/js/jsnodes.js
@@ -32,6 +32,7 @@ app.registerExtension({
 			case "ImageBatchMulti":
 			case "ImageAddMulti":
 			case "ImageConcatMulti":
+			case "CrossFadeImagesMulti":
 				nodeType.prototype.onNodeCreated = function () {
 					this._type = "IMAGE"
 					this.inputs_offset = nodeData.name.includes("selective")?1:0
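
For context, the heart of the new node is the per-seam blend: for each extra input it crossfades from the last frame of the accumulated batch to the first frame of the next batch, with the blend weight `alpha` shaped by the chosen easing curve. The sketch below is a minimal, standalone illustration of that blend in plain PyTorch; the helper name `blend_transition` and the demo tensors are hypothetical and not part of the diff or the KJNodes API.

```python
import torch

def blend_transition(batch_a: torch.Tensor, batch_b: torch.Tensor, steps: int,
                     easing=lambda t: t) -> torch.Tensor:
    # Crossfade from the last frame of batch_a to the first frame of batch_b
    # over `steps` frames, using the same blend as the node:
    #   frame = (1 - alpha) * last_a + alpha * first_b
    last_a, first_b = batch_a[-1], batch_b[0]
    frames = []
    for i in range(steps):
        t = i / max(steps - 1, 1)   # normalized position within the transition
        alpha = easing(t)           # easing curve maps t to the blend weight
        frames.append((1 - alpha) * last_a + alpha * first_b)
    mid = torch.stack(frames) if frames else batch_a[:0]  # empty when steps == 0
    return torch.cat((batch_a, mid, batch_b), dim=0)

# Two 4-frame batches of 64x64 RGB images in ComfyUI's (B, H, W, C) layout
a = torch.zeros(4, 64, 64, 3)
b = torch.ones(4, 64, 64, 3)
out = blend_transition(a, b, steps=8, easing=lambda t: t * t)  # "ease_in" curve
print(out.shape)  # torch.Size([16, 64, 64, 3])
```

The node repeats this blend for every additional input and concatenates the running batch, the generated transition frames, and the next batch along the batch dimension, so the output length is the sum of all input batch sizes plus `transitioning_frames` per seam.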