Update curve_nodes.py

This commit is contained in:
kijai 2024-05-12 02:08:53 +03:00
parent 9991155130
commit 5631cd0146

View File

@ -305,8 +305,8 @@ Locations are center locations.
class CreateShapeImageOnPath: class CreateShapeImageOnPath:
RETURN_TYPES = ("IMAGE", ) RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("image", ) RETURN_NAMES = ("image","mask", )
FUNCTION = "createshapemask" FUNCTION = "createshapemask"
CATEGORY = "KJNodes/image" CATEGORY = "KJNodes/image"
DESCRIPTION = """ DESCRIPTION = """
@ -348,7 +348,8 @@ Locations are center locations.
coordinates = json.loads(coordinates) coordinates = json.loads(coordinates)
batch_size = len(coordinates) batch_size = len(coordinates)
out = [] images_list = []
masks_list = []
if len(size_multiplier) != batch_size: if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)] size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
@ -386,9 +387,12 @@ Locations are center locations.
image = pil2tensor(image) image = pil2tensor(image)
image = image * intensity image = image * intensity
out.append(image) mask = image[:, :, :, 0]
outstack = torch.cat(out, dim=0) masks_list.append(mask)
return (outstack,) images_list.append(image)
out_images = torch.cat(images_list, dim=0).cpu().float()
out_masks = torch.cat(masks_list, dim=0)
return (out_images, out_masks)
class CreateTextOnPath: class CreateTextOnPath: