Add CreateShapeMask

This commit is contained in:
kijai 2023-11-17 10:25:22 +02:00
parent 85140734f0
commit 5aeeed35f6

View File

@ -2031,6 +2031,71 @@ class WidgetToString:
raise NameError(f"Node not found: {id}")
return (', '.join(results).strip(', '), )
class CreateShapeMask:
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
FUNCTION = "createshapemask"
CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"shape": (
[ 'circle',
'square',
'triangle',
],
{
"default": 'circle'
}),
"frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
"location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
},
}
def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, size, grow, shape):
# Define the number of images in the batch
batch_size = frames
out = []
color = "white"
for i in range(batch_size):
image = Image.new("RGB", (frame_width, frame_height), "black")
draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0
current_size = max(0, size + i*grow)
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_size // 2, location_y - current_size // 2)
right_down_point = (location_x + current_size // 2, location_y + current_size // 2)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=color)
elif shape == 'square':
draw.rectangle(two_points, fill=color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_size // 2, location_y + current_size // 2) # bottom left
right_down_point = (location_x + current_size // 2, location_y + current_size // 2) # bottom right
top_point = (location_x, location_y - current_size // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=color)
image = pil2tensor(image)
mask = image[:, :, :, 0]
out.append(mask)
return (torch.cat(out, dim=0), 1.0 - torch.cat(out, dim=0),)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
@ -2068,6 +2133,7 @@ NODE_CLASS_MAPPINGS = {
"ResizeMask": ResizeMask,
"OffsetMask": OffsetMask,
"WidgetToString": WidgetToString,
"CreateShapeMask": CreateShapeMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -2105,4 +2171,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ResizeMask": "ResizeMask",
"OffsetMask": "OffsetMask",
"WidgetToString": "WidgetToString",
"CreateShapeMask": "CreateShapeMask",
}