diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index 93112d6..6ac2238 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -305,8 +305,8 @@ Locations are center locations. class CreateShapeImageOnPath: - RETURN_TYPES = ("IMAGE", ) - RETURN_NAMES = ("image", ) + RETURN_TYPES = ("IMAGE", "MASK",) + RETURN_NAMES = ("image","mask", ) FUNCTION = "createshapemask" CATEGORY = "KJNodes/image" DESCRIPTION = """ @@ -348,7 +348,8 @@ Locations are center locations. coordinates = json.loads(coordinates) batch_size = len(coordinates) - out = [] + images_list = [] + masks_list = [] if len(size_multiplier) != batch_size: size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)] @@ -386,9 +387,12 @@ Locations are center locations. image = pil2tensor(image) image = image * intensity - out.append(image) - outstack = torch.cat(out, dim=0) - return (outstack,) + mask = image[:, :, :, 0] + masks_list.append(mask) + images_list.append(image) + out_images = torch.cat(images_list, dim=0).cpu().float() + out_masks = torch.cat(masks_list, dim=0) + return (out_images, out_masks) class CreateTextOnPath: