Fix the nodes to actually work

This commit is contained in:
kijai 2023-10-07 23:51:09 +03:00
parent d526683f25
commit d566c3c35a

View File

@ -63,14 +63,17 @@ class CreateGradientMask:
offset_gradient = gradient - time # Offset the gradient values based on time
image_batch[i] = offset_gradient.reshape(1, -1)
output = torch.from_numpy(image_batch)
out.append(output)
mask = output
print("gradientmaskshape")
print(mask.shape)
out.append(mask)
if invert:
return (1.0 - torch.stack(out, dim=0),)
return (torch.stack(out, dim=0),)
return (1.0 - torch.cat(out, dim=0),)
return (torch.cat(out, dim=0),)
class CreateTextMask:
RETURN_TYPES = ("MASK",)
RETURN_TYPES = ("IMAGE", "MASK",)
FUNCTION = "createtextmask"
CATEGORY = "KJNodes"
@ -96,10 +99,10 @@ class CreateTextMask:
# Define the number of images in the batch
batch_size = frames
out = []
masks = []
rotation = start_rotation
rotation_increment = (end_rotation - start_rotation) / (batch_size - 1)
# Create an empty array to store the image batch
image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
if frames > 1:
rotation_increment = (end_rotation - start_rotation) / (batch_size - 1)
if font_path == "fonts\\TTNorms-Black.otf": #I don't know why relative path won't work otherwise...
font_path = os.path.join(script_dir, font_path)
# Generate the text
@ -112,14 +115,15 @@ class CreateTextMask:
text_center_y = text_y + text_height / 2
draw.text((text_x, text_y), text, font=font, fill="white")
image = image.rotate(rotation, center=(text_center_x, text_center_y))
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
mask = image[:, :, :, 0]
masks.append(mask)
out.append(image)
rotation += rotation_increment
image_batch[i] = np.array(image.convert("L"))
output = torch.from_numpy(image_batch)
rotation += 10
out.append(output)
if invert:
return (1.0 - torch.stack(out, dim=0),)
return (torch.stack(out, dim=0),)
return (1.0 - torch.cat(out, dim=0),)
return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
class GrowMaskWithBlur:
@classmethod
@ -173,6 +177,7 @@ class GrowMaskWithBlur:
else:
expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change
output = torch.from_numpy(output)
print(output.shape)
out.append(output)
blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)