This commit is contained in:
kijai 2024-10-22 18:15:42 +03:00
parent b3e5108ad4
commit 0227f7b77f
2 changed files with 50 additions and 28 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
from torchvision import transforms from torchvision import transforms
import json import json
from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter, ImageChops
import numpy as np import numpy as np
from ..utility.utility import pil2tensor from ..utility.utility import pil2tensor
import folder_paths import folder_paths
@ -364,16 +364,23 @@ Locations are center locations.
}, },
"optional": { "optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}), "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
"trailing": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
} }
} }
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color, def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color,
bg_color, blur_radius, shape, intensity, size_multiplier=[1.0]): bg_color, blur_radius, shape, intensity, size_multiplier=[1.0], accumulate=False, trailing=1.0):
# Define the number of images in the batch # Define the number of images in the batch
coordinates = coordinates.replace("'", '"') if len(coordinates) < 10:
coordinates = json.loads(coordinates) coords_list = []
for coords in coordinates:
coords = json.loads(coords.replace("'", '"'))
coords_list.append(coords)
else:
coords = json.loads(coordinates.replace("'", '"'))
coords_list = [coords]
batch_size = len(coordinates) batch_size = len(coords_list[0])
images_list = [] images_list = []
masks_list = [] masks_list = []
@ -381,39 +388,49 @@ Locations are center locations.
size_multiplier = [0] * batch_size size_multiplier = [0] * batch_size
else: else:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)] size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
previous_output = None
for i in range(batch_size):
image = Image.new("RGB", (frame_width, frame_height), bg_color) image = Image.new("RGB", (frame_width, frame_height), bg_color)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0 # Calculate the size for this frame and ensure it's not less than 0
current_width = max(0, shape_width + i * size_multiplier[i]) current_width = max(0, shape_width + i * size_multiplier[i])
current_height = max(0, shape_height + i * size_multiplier[i]) current_height = max(0, shape_height + i * size_multiplier[i])
for coords in coords_list:
location_x = coords[i]['x']
location_y = coords[i]['y']
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
two_points = [left_up_point, right_down_point]
location_x = coord['x'] if shape == 'circle':
location_y = coord['y'] draw.ellipse(two_points, fill=shape_color)
elif shape == 'square':
if shape == 'circle' or shape == 'square': draw.rectangle(two_points, fill=shape_color)
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2) elif shape == 'triangle':
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # Define the points for the triangle
two_points = [left_up_point, right_down_point] left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
if shape == 'circle': top_point = (location_x, location_y - current_height // 2) # top point
draw.ellipse(two_points, fill=shape_color) draw.polygon([top_point, left_up_point, right_down_point], fill=shape_color)
elif shape == 'square':
draw.rectangle(two_points, fill=shape_color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
top_point = (location_x, location_y - current_height // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=shape_color)
if blur_radius != 0: if blur_radius != 0:
image = image.filter(ImageFilter.GaussianBlur(blur_radius)) image = image.filter(ImageFilter.GaussianBlur(blur_radius))
# Blend the current image with the accumulated image
image = pil2tensor(image) image = pil2tensor(image)
if trailing != 1.0 and previous_output is not None:
# Add the decayed previous output to the current frame
image += trailing * previous_output
image = image / image.max()
previous_output = image
image = image * intensity image = image * intensity
mask = image[:, :, :, 0] mask = image[:, :, :, 0]
masks_list.append(mask) masks_list.append(mask)

View File

@ -233,7 +233,12 @@ class AppendStringsToList:
CATEGORY = "KJNodes/constants" CATEGORY = "KJNodes/constants"
def joinstring(self, string1, string2): def joinstring(self, string1, string2):
joined_string = [string1, string2] if not isinstance(string1, list):
string1 = [string1]
if not isinstance(string2, list):
string2 = [string2]
joined_string = string1 + string2
return (joined_string, ) return (joined_string, )
class JoinStrings: class JoinStrings: