diff --git a/nodes.py b/nodes.py index dd09345..65d0094 100644 --- a/nodes.py +++ b/nodes.py @@ -63,14 +63,17 @@ class CreateGradientMask: offset_gradient = gradient - time # Offset the gradient values based on time image_batch[i] = offset_gradient.reshape(1, -1) output = torch.from_numpy(image_batch) - out.append(output) + mask = output + print("gradientmaskshape") + print(mask.shape) + out.append(mask) if invert: - return (1.0 - torch.stack(out, dim=0),) - return (torch.stack(out, dim=0),) + return (1.0 - torch.cat(out, dim=0),) + return (torch.cat(out, dim=0),) class CreateTextMask: - RETURN_TYPES = ("MASK",) + RETURN_TYPES = ("IMAGE", "MASK",) FUNCTION = "createtextmask" CATEGORY = "KJNodes" @@ -96,10 +99,10 @@ class CreateTextMask: # Define the number of images in the batch batch_size = frames out = [] + masks = [] rotation = start_rotation - rotation_increment = (end_rotation - start_rotation) / (batch_size - 1) - # Create an empty array to store the image batch - image_batch = np.zeros((batch_size, height, width), dtype=np.float32) + if frames > 1: + rotation_increment = (end_rotation - start_rotation) / (batch_size - 1) if font_path == "fonts\\TTNorms-Black.otf": #I don't know why relative path won't work otherwise... font_path = os.path.join(script_dir, font_path) # Generate the text @@ -112,14 +115,15 @@ class CreateTextMask: text_center_y = text_y + text_height / 2 draw.text((text_x, text_y), text, font=font, fill="white") image = image.rotate(rotation, center=(text_center_x, text_center_y)) + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + mask = image[:, :, :, 0] + masks.append(mask) + out.append(image) rotation += rotation_increment - image_batch[i] = np.array(image.convert("L")) - output = torch.from_numpy(image_batch) - rotation += 10 - out.append(output) if invert: - return (1.0 - torch.stack(out, dim=0),) - return (torch.stack(out, dim=0),) + return (1.0 - torch.cat(out, dim=0),) + return (torch.cat(out, dim=0),torch.cat(masks, dim=0),) class GrowMaskWithBlur: @classmethod @@ -173,6 +177,7 @@ class GrowMaskWithBlur: else: expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change output = torch.from_numpy(output) + print(output.shape) out.append(output) blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)