This commit is contained in:
kijai 2024-10-22 18:15:42 +03:00
parent b3e5108ad4
commit 0227f7b77f
2 changed files with 50 additions and 28 deletions

View File

@ -1,7 +1,7 @@
import torch
from torchvision import transforms
import json
from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter
from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter, ImageChops
import numpy as np
from ..utility.utility import pil2tensor
import folder_paths
@ -364,16 +364,23 @@ Locations are center locations.
},
"optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
"trailing": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}
}
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color,
bg_color, blur_radius, shape, intensity, size_multiplier=[1.0]):
bg_color, blur_radius, shape, intensity, size_multiplier=[1.0], accumulate=False, trailing=1.0):
# Define the number of images in the batch
coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates)
if len(coordinates) < 10:
coords_list = []
for coords in coordinates:
coords = json.loads(coords.replace("'", '"'))
coords_list.append(coords)
else:
coords = json.loads(coordinates.replace("'", '"'))
coords_list = [coords]
batch_size = len(coordinates)
batch_size = len(coords_list[0])
images_list = []
masks_list = []
@ -381,39 +388,49 @@ Locations are center locations.
size_multiplier = [0] * batch_size
else:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
previous_output = None
for i in range(batch_size):
image = Image.new("RGB", (frame_width, frame_height), bg_color)
draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0
current_width = max(0, shape_width + i * size_multiplier[i])
current_height = max(0, shape_height + i * size_multiplier[i])
for coords in coords_list:
location_x = coords[i]['x']
location_y = coords[i]['y']
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
two_points = [left_up_point, right_down_point]
location_x = coord['x']
location_y = coord['y']
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=shape_color)
elif shape == 'square':
draw.rectangle(two_points, fill=shape_color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
top_point = (location_x, location_y - current_height // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=shape_color)
if shape == 'circle':
draw.ellipse(two_points, fill=shape_color)
elif shape == 'square':
draw.rectangle(two_points, fill=shape_color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
top_point = (location_x, location_y - current_height // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=shape_color)
if blur_radius != 0:
image = image.filter(ImageFilter.GaussianBlur(blur_radius))
# Blend the current image with the accumulated image
image = pil2tensor(image)
if trailing != 1.0 and previous_output is not None:
# Add the decayed previous output to the current frame
image += trailing * previous_output
image = image / image.max()
previous_output = image
image = image * intensity
mask = image[:, :, :, 0]
masks_list.append(mask)

View File

@ -233,7 +233,12 @@ class AppendStringsToList:
CATEGORY = "KJNodes/constants"
def joinstring(self, string1, string2):
joined_string = [string1, string2]
if not isinstance(string1, list):
string1 = [string1]
if not isinstance(string2, list):
string2 = [string2]
joined_string = string1 + string2
return (joined_string, )
class JoinStrings: