diff --git a/nodes.py b/nodes.py index 23a4308..d3a2b5c 100644 --- a/nodes.py +++ b/nodes.py @@ -243,10 +243,11 @@ class CreateFadeMask: "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}), "midpoint_level": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}), "end_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}), + "midpoint_frame": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}), }, } - def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level): + def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level, midpoint_frame=None): def ease_in(t): return t * t @@ -259,32 +260,38 @@ class CreateFadeMask: batch_size = frames out = [] image_batch = np.zeros((batch_size, height, width), dtype=np.float32) - + + if midpoint_frame is None: + midpoint_frame = batch_size // 2 + for i in range(batch_size): - t = i / (batch_size - 1) - - if interpolation == "ease_in": - t = ease_in(t) - elif interpolation == "ease_out": - t = ease_out(t) - elif interpolation == "ease_in_out": - t = ease_in_out(t) - - if midpoint_level is not None: - if t < 0.5: - color = start_level - t * (start_level - midpoint_level) * 2 - else: - color = midpoint_level - (t - 0.5) * (midpoint_level - end_level) * 2 + if i <= midpoint_frame: + t = i / midpoint_frame + if interpolation == "ease_in": + t = ease_in(t) + elif interpolation == "ease_out": + t = ease_out(t) + elif interpolation == "ease_in_out": + t = ease_in_out(t) + color = start_level - t * (start_level - midpoint_level) else: - color = start_level - t * (start_level - end_level) - + t = (i - midpoint_frame) / (batch_size - midpoint_frame) + if interpolation == "ease_in": + t = ease_in(t) + elif interpolation == "ease_out": + t = ease_out(t) + elif interpolation == "ease_in_out": + t = ease_in_out(t) + color = midpoint_level - t * (midpoint_level - end_level) + + color = np.clip(color, 0, 255) image = np.full((height, width), color, dtype=np.float32) image_batch[i] = image - + output = torch.from_numpy(image_batch) mask = output out.append(mask) - + if invert: return (1.0 - torch.cat(out, dim=0),) return (torch.cat(out, dim=0),)