Add CreateFadeMaskAdvanced

This commit is contained in:
kijai 2023-11-21 15:41:20 +02:00
parent 4a2ae659c1
commit 8e3b2abb20
2 changed files with 85 additions and 0 deletions

View File

@ -298,6 +298,88 @@ class CreateFadeMask:
return (1.0 - torch.cat(out, dim=0),)
return (torch.cat(out, dim=0),)
class CreateFadeMaskAdvanced:
RETURN_TYPES = ("MASK",)
FUNCTION = "createfademask"
CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n16:(0.0)\n", "multiline": True}),
"invert": ("BOOLEAN", {"default": False}),
"frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}),
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
},
}
def createfademask(self, frames, width, height, invert, points_string, interpolation):
def ease_in(t):
return t * t
def ease_out(t):
return 1 - (1 - t) * (1 - t)
def ease_in_out(t):
return 3 * t * t - 2 * t * t * t
# Parse the input string into a list of tuples
points = []
points_string = points_string.rstrip(',\n')
for point_str in points_string.split(','):
frame_str, color_str = point_str.split(':')
frame = int(frame_str.strip())
color = float(color_str.strip()[1:-1]) # Remove parentheses around color
points.append((frame, color))
# Check if the last frame is already in the points
if len(points) == 0 or points[-1][0] != frames - 1:
# If not, add it with the color of the last specified frame
points.append((frames - 1, points[-1][1] if points else 0))
# Sort the points by frame number
points.sort(key=lambda x: x[0])
batch_size = frames
out = []
image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
# Index of the next point to interpolate towards
next_point = 1
for i in range(batch_size):
while next_point < len(points) and i > points[next_point][0]:
next_point += 1
# Interpolate between the previous point and the next point
prev_point = next_point - 1
t = (i - points[prev_point][0]) / (points[next_point][0] - points[prev_point][0])
if interpolation == "ease_in":
t = ease_in(t)
elif interpolation == "ease_out":
t = ease_out(t)
elif interpolation == "ease_in_out":
t = ease_in_out(t)
elif interpolation == "linear":
pass # No need to modify `t` for linear interpolation
color = points[prev_point][1] - t * (points[prev_point][1] - points[next_point][1])
color = np.clip(color, 0, 255)
image = np.full((height, width), color, dtype=np.float32)
image_batch[i] = image
output = torch.from_numpy(image_batch)
mask = output
out.append(mask)
if invert:
return (1.0 - torch.cat(out, dim=0),)
return (torch.cat(out, dim=0),)
class CrossFadeImages:
RETURN_TYPES = ("IMAGE",)
@ -2323,6 +2405,7 @@ NODE_CLASS_MAPPINGS = {
"CreateTextMask": CreateTextMask,
"CreateAudioMask": CreateAudioMask,
"CreateFadeMask": CreateFadeMask,
"CreateFadeMaskAdvanced": CreateFadeMaskAdvanced,
"CreateFluidMask" :CreateFluidMask,
"VRAM_Debug" : VRAM_Debug,
"SomethingToString" : SomethingToString,
@ -2364,6 +2447,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CreateGradientMask": "CreateGradientMask",
"CreateTextMask" : "CreateTextMask",
"CreateFadeMask" : "CreateFadeMask",
"CreateFadeMaskAdvanced" : "CreateFadeMaskAdvanced",
"CreateFluidMask" : "CreateFluidMask",
"VRAM_Debug" : "VRAM Debug",
"CrossFadeImages": "CrossFadeImages",

View File

@ -3,3 +3,4 @@ numpy
Pillow
scipy
color-matcher
matplotlib