Update nodes.py

This commit is contained in:
kijai 2023-10-28 14:17:48 +03:00
parent 53a16abc23
commit 888bfa5cff

View File

@ -279,6 +279,7 @@ class CrossFadeImages:
"images_1": ("IMAGE",), "images_1": ("IMAGE",),
"images_2": ("IMAGE",), "images_2": ("IMAGE",),
"interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],), "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
"batch_size": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
"transition_start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), "transition_start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
"transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
"start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}), "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
@ -286,7 +287,7 @@ class CrossFadeImages:
}, },
} }
def crossfadeimages(self, images_1, images_2, transition_start_index, transitioning_frames, interpolation, start_level, end_level): def crossfadeimages(self, images_1, images_2, transition_start_index, transitioning_frames, interpolation, batch_size, start_level, end_level):
def crossfade(images_1, images_2, alpha): def crossfade(images_1, images_2, alpha):
crossfade = (1 - alpha) * images_1 + alpha * images_2 crossfade = (1 - alpha) * images_1 + alpha * images_2
@ -320,10 +321,8 @@ class CrossFadeImages:
"exponential_ease_out": exponential_ease_out, "exponential_ease_out": exponential_ease_out,
} }
#batch_size = images_1.size(0)
batch_size = images_1.size(0)
crossfade_images = [] crossfade_images = []
#transition_frame_length = int(batch_size / transitioning_frames)
alphas = torch.linspace(start_level, end_level, batch_size) alphas = torch.linspace(start_level, end_level, batch_size)
for i in range(batch_size): for i in range(batch_size):