diff --git a/nodes.py b/nodes.py index f9f1f94..36f829f 100644 --- a/nodes.py +++ b/nodes.py @@ -2031,6 +2031,71 @@ class WidgetToString: raise NameError(f"Node not found: {id}") return (', '.join(results).strip(', '), ) +class CreateShapeMask: + + RETURN_TYPES = ("MASK", "MASK",) + RETURN_NAMES = ("mask", "mask_inverted",) + FUNCTION = "createshapemask" + CATEGORY = "KJNodes/masking/generate" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "shape": ( + [ 'circle', + 'square', + 'triangle', + ], + { + "default": 'circle' + }), + "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}), + "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}), + "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}), + "size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}), + "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}), + "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), + "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), + }, + } + + def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, size, grow, shape): + # Define the number of images in the batch + batch_size = frames + out = [] + color = "white" + for i in range(batch_size): + image = Image.new("RGB", (frame_width, frame_height), "black") + draw = ImageDraw.Draw(image) + + # Calculate the size for this frame and ensure it's not less than 0 + current_size = max(0, size + i*grow) + + if shape == 'circle' or shape == 'square': + # Define the bounding box for the shape + left_up_point = (location_x - current_size // 2, location_y - current_size // 2) + right_down_point = (location_x + current_size // 2, location_y + current_size // 2) + two_points = [left_up_point, right_down_point] + + if shape == 'circle': + draw.ellipse(two_points, fill=color) + elif shape == 'square': + draw.rectangle(two_points, fill=color) + + elif shape == 'triangle': + # Define the points for the triangle + left_up_point = (location_x - current_size // 2, location_y + current_size // 2) # bottom left + right_down_point = (location_x + current_size // 2, location_y + current_size // 2) # bottom right + top_point = (location_x, location_y - current_size // 2) # top point + draw.polygon([top_point, left_up_point, right_down_point], fill=color) + + image = pil2tensor(image) + mask = image[:, :, :, 0] + out.append(mask) + + return (torch.cat(out, dim=0), 1.0 - torch.cat(out, dim=0),) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -2068,6 +2133,7 @@ NODE_CLASS_MAPPINGS = { "ResizeMask": ResizeMask, "OffsetMask": OffsetMask, "WidgetToString": WidgetToString, + "CreateShapeMask": CreateShapeMask, } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -2105,4 +2171,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ResizeMask": "ResizeMask", "OffsetMask": "OffsetMask", "WidgetToString": "WidgetToString", + "CreateShapeMask": "CreateShapeMask", } \ No newline at end of file