Update CrossFadeImages to allow joining different batch sizes

This commit is contained in:
kijai 2025-07-14 15:24:05 +03:00
parent 9ea455afd6
commit 4549812bcc

View File

@ -1265,7 +1265,7 @@ class CrossFadeImages:
"images_1": ("IMAGE",), "images_1": ("IMAGE",),
"images_2": ("IMAGE",), "images_2": ("IMAGE",),
"interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],), "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
"transition_start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), "transition_start_index": ("INT", {"default": 1,"min": -4096, "max": 4096, "step": 1}),
"transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
"start_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}), "start_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
"end_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}), "end_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
@ -1308,36 +1308,36 @@ class CrossFadeImages:
crossfade_images = [] crossfade_images = []
if transition_start_index < 0:
transition_start_index = len(images_1) + transition_start_index
if transition_start_index < 0:
raise ValueError("Transition start index is out of range for images_1.")
transitioning_frames = min(transitioning_frames, len(images_1) - transition_start_index, len(images_2))
alphas = torch.linspace(start_level, end_level, transitioning_frames) alphas = torch.linspace(start_level, end_level, transitioning_frames)
for i in range(transitioning_frames): for i in range(transitioning_frames):
alpha = alphas[i] alpha = alphas[i]
image1 = images_1[i + transition_start_index] image1 = images_1[transition_start_index + i]
image2 = images_2[i + transition_start_index] image2 = images_2[i]
easing_function = easing_functions.get(interpolation) easing_function = easing_functions.get(interpolation)
alpha = easing_function(alpha) # Apply the easing function to the alpha value alpha = easing_function(alpha) # Apply the easing function to the alpha value
crossfade_image = crossfade(image1, image2, alpha) crossfade_image = crossfade(image1, image2, alpha)
crossfade_images.append(crossfade_image) crossfade_images.append(crossfade_image)
# Convert crossfade_images to tensor # Convert crossfade_images to tensor
crossfade_images = torch.stack(crossfade_images, dim=0) crossfade_images = torch.stack(crossfade_images, dim=0)
# Get the last frame result of the interpolation
last_frame = crossfade_images[-1]
# Calculate the number of remaining frames from images_2
remaining_frames = len(images_2) - (transition_start_index + transitioning_frames)
# Crossfade the remaining frames with the last used alpha value
for i in range(remaining_frames):
alpha = alphas[-1]
image1 = images_1[i + transition_start_index + transitioning_frames]
image2 = images_2[i + transition_start_index + transitioning_frames]
easing_function = easing_functions.get(interpolation)
alpha = easing_function(alpha) # Apply the easing function to the alpha value
crossfade_image = crossfade(image1, image2, alpha) # Append the beginning of images_1 (before the transition)
crossfade_images = torch.cat([crossfade_images, crossfade_image.unsqueeze(0)], dim=0)
# Append the beginning of images_1
beginning_images_1 = images_1[:transition_start_index] beginning_images_1 = images_1[:transition_start_index]
crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0) crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0)
# Append the remaining frames of images_2 (after the transition)
remaining_images_2 = images_2[transitioning_frames:]
if len(remaining_images_2) > 0:
crossfade_images = torch.cat([crossfade_images, remaining_images_2], dim=0)
return (crossfade_images, ) return (crossfade_images, )
class CrossFadeImagesMulti: class CrossFadeImagesMulti: