diff --git a/nodes.py b/nodes.py index 03f9208..cedb7b5 100644 --- a/nodes.py +++ b/nodes.py @@ -298,6 +298,88 @@ class CreateFadeMask: return (1.0 - torch.cat(out, dim=0),) return (torch.cat(out, dim=0),) +class CreateFadeMaskAdvanced: + + RETURN_TYPES = ("MASK",) + FUNCTION = "createfademask" + CATEGORY = "KJNodes/masking/generate" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n16:(0.0)\n", "multiline": True}), + "invert": ("BOOLEAN", {"default": False}), + "frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}), + "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), + "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), + "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],), + }, + } + + def createfademask(self, frames, width, height, invert, points_string, interpolation): + def ease_in(t): + return t * t + + def ease_out(t): + return 1 - (1 - t) * (1 - t) + + def ease_in_out(t): + return 3 * t * t - 2 * t * t * t + + # Parse the input string into a list of tuples + points = [] + points_string = points_string.rstrip(',\n') + for point_str in points_string.split(','): + frame_str, color_str = point_str.split(':') + frame = int(frame_str.strip()) + color = float(color_str.strip()[1:-1]) # Remove parentheses around color + points.append((frame, color)) + + # Check if the last frame is already in the points + if len(points) == 0 or points[-1][0] != frames - 1: + # If not, add it with the color of the last specified frame + points.append((frames - 1, points[-1][1] if points else 0)) + + # Sort the points by frame number + points.sort(key=lambda x: x[0]) + + batch_size = frames + out = [] + image_batch = np.zeros((batch_size, height, width), dtype=np.float32) + + # Index of the next point to interpolate towards + next_point = 1 + + for i in range(batch_size): + while next_point < len(points) and i > points[next_point][0]: + next_point += 1 + + # Interpolate between the previous point and the next point + prev_point = next_point - 1 + t = (i - points[prev_point][0]) / (points[next_point][0] - points[prev_point][0]) + if interpolation == "ease_in": + t = ease_in(t) + elif interpolation == "ease_out": + t = ease_out(t) + elif interpolation == "ease_in_out": + t = ease_in_out(t) + elif interpolation == "linear": + pass # No need to modify `t` for linear interpolation + + color = points[prev_point][1] - t * (points[prev_point][1] - points[next_point][1]) + color = np.clip(color, 0, 255) + image = np.full((height, width), color, dtype=np.float32) + image_batch[i] = image + + output = torch.from_numpy(image_batch) + mask = output + out.append(mask) + + if invert: + return (1.0 - torch.cat(out, dim=0),) + return (torch.cat(out, dim=0),) + class CrossFadeImages: RETURN_TYPES = ("IMAGE",) @@ -2323,6 +2405,7 @@ NODE_CLASS_MAPPINGS = { "CreateTextMask": CreateTextMask, "CreateAudioMask": CreateAudioMask, "CreateFadeMask": CreateFadeMask, + "CreateFadeMaskAdvanced": CreateFadeMaskAdvanced, "CreateFluidMask" :CreateFluidMask, "VRAM_Debug" : VRAM_Debug, "SomethingToString" : SomethingToString, @@ -2364,6 +2447,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CreateGradientMask": "CreateGradientMask", "CreateTextMask" : "CreateTextMask", "CreateFadeMask" : "CreateFadeMask", + "CreateFadeMaskAdvanced" : "CreateFadeMaskAdvanced", "CreateFluidMask" : "CreateFluidMask", "VRAM_Debug" : "VRAM Debug", "CrossFadeImages": "CrossFadeImages", diff --git a/requirements.txt b/requirements.txt index 1afd2c1..e93dc99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ numpy Pillow scipy color-matcher +matplotlib