This commit is contained in:
kijai 2024-10-22 18:15:42 +03:00
parent b3e5108ad4
commit 0227f7b77f
2 changed files with 50 additions and 28 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
from torchvision import transforms from torchvision import transforms
import json import json
from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter, ImageChops
import numpy as np import numpy as np
from ..utility.utility import pil2tensor from ..utility.utility import pil2tensor
import folder_paths import folder_paths
@ -364,16 +364,23 @@ Locations are center locations.
}, },
"optional": { "optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}), "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
"trailing": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
} }
} }
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color, def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape_color,
bg_color, blur_radius, shape, intensity, size_multiplier=[1.0]): bg_color, blur_radius, shape, intensity, size_multiplier=[1.0], accumulate=False, trailing=1.0):
# Define the number of images in the batch # Define the number of images in the batch
coordinates = coordinates.replace("'", '"') if len(coordinates) < 10:
coordinates = json.loads(coordinates) coords_list = []
for coords in coordinates:
coords = json.loads(coords.replace("'", '"'))
coords_list.append(coords)
else:
coords = json.loads(coordinates.replace("'", '"'))
coords_list = [coords]
batch_size = len(coordinates) batch_size = len(coords_list[0])
images_list = [] images_list = []
masks_list = [] masks_list = []
@ -381,7 +388,10 @@ Locations are center locations.
size_multiplier = [0] * batch_size size_multiplier = [0] * batch_size
else: else:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)] size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
previous_output = None
for i in range(batch_size):
image = Image.new("RGB", (frame_width, frame_height), bg_color) image = Image.new("RGB", (frame_width, frame_height), bg_color)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
@ -389,8 +399,9 @@ Locations are center locations.
current_width = max(0, shape_width + i * size_multiplier[i]) current_width = max(0, shape_width + i * size_multiplier[i])
current_height = max(0, shape_height + i * size_multiplier[i]) current_height = max(0, shape_height + i * size_multiplier[i])
location_x = coord['x'] for coords in coords_list:
location_y = coord['y'] location_x = coords[i]['x']
location_y = coords[i]['y']
if shape == 'circle' or shape == 'square': if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape # Define the bounding box for the shape
@ -412,8 +423,14 @@ Locations are center locations.
if blur_radius != 0: if blur_radius != 0:
image = image.filter(ImageFilter.GaussianBlur(blur_radius)) image = image.filter(ImageFilter.GaussianBlur(blur_radius))
# Blend the current image with the accumulated image
image = pil2tensor(image) image = pil2tensor(image)
if trailing != 1.0 and previous_output is not None:
# Add the decayed previous output to the current frame
image += trailing * previous_output
image = image / image.max()
previous_output = image
image = image * intensity image = image * intensity
mask = image[:, :, :, 0] mask = image[:, :, :, 0]
masks_list.append(mask) masks_list.append(mask)

View File

@ -233,7 +233,12 @@ class AppendStringsToList:
CATEGORY = "KJNodes/constants" CATEGORY = "KJNodes/constants"
def joinstring(self, string1, string2): def joinstring(self, string1, string2):
joined_string = [string1, string2] if not isinstance(string1, list):
string1 = [string1]
if not isinstance(string2, list):
string2 = [string2]
joined_string = string1 + string2
return (joined_string, ) return (joined_string, )
class JoinStrings: class JoinStrings: