Initial version of voronoi mask

works but WIP
This commit is contained in:
kijai 2023-11-19 03:03:25 +02:00
parent b408a75579
commit 02b00c6480

View File

@ -4,6 +4,8 @@ import torch.nn.functional as F
from torchvision.transforms import Resize, CenterCrop, InterpolationMode from torchvision.transforms import Resize, CenterCrop, InterpolationMode
from torchvision.transforms import functional as TF from torchvision.transforms import functional as TF
import scipy.ndimage import scipy.ndimage
from scipy.spatial import Voronoi, voronoi_plot_2d
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from PIL import ImageFilter, Image, ImageDraw, ImageFont from PIL import ImageFilter, Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -2115,6 +2117,70 @@ class CreateShapeMask:
return (torch.cat(out, dim=0), 1.0 - torch.cat(out, dim=0),) return (torch.cat(out, dim=0), 1.0 - torch.cat(out, dim=0),)
class CreateVoronoiMask:
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
FUNCTION = "createvoronoi"
CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
"num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
"line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
"speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
},
}
def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
# Define the number of images in the batch
batch_size = frames
out = []
# Create start and end points for each point
start_points = np.random.rand(num_points, 2)
end_points = np.random.rand(num_points, 2)
for i in range(batch_size):
# Interpolate the points' positions based on the current frame
t = (i * speed) / (batch_size - 1) # normalize to [0, 1] over the frames
t = np.clip(t, 0, 1) # ensure t is in [0, 1]
points = (1 - t) * start_points + t * end_points # lerp
vor = Voronoi(points)
# Create a blank image with a white background
fig, ax = plt.subplots()
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax.set_xlim([0, 1]); ax.set_ylim([0, 1])
ax.axis('off')
ax.margins(0, 0)
fig.set_size_inches(frame_width/100, frame_height/100)
ax.fill_between([0, 1], [0, 1], color='white')
# Plot each Voronoi ridge
for simplex in vor.ridge_vertices:
simplex = np.asarray(simplex)
if np.all(simplex >= 0):
plt.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)
fig.canvas.draw()
img = np.array(fig.canvas.renderer._renderer)
plt.close(fig)
pil_img = Image.fromarray(img).convert("L")
mask = torch.tensor(np.array(pil_img)) / 255.0
out.append(mask)
return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant, "INTConstant": INTConstant,
"FloatConstant": FloatConstant, "FloatConstant": FloatConstant,
@ -2153,6 +2219,7 @@ NODE_CLASS_MAPPINGS = {
"OffsetMask": OffsetMask, "OffsetMask": OffsetMask,
"WidgetToString": WidgetToString, "WidgetToString": WidgetToString,
"CreateShapeMask": CreateShapeMask, "CreateShapeMask": CreateShapeMask,
"CreateVoronoiMask": CreateVoronoiMask,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant", "INTConstant": "INT Constant",
@ -2191,4 +2258,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"OffsetMask": "OffsetMask", "OffsetMask": "OffsetMask",
"WidgetToString": "WidgetToString", "WidgetToString": "WidgetToString",
"CreateShapeMask": "CreateShapeMask", "CreateShapeMask": "CreateShapeMask",
"CreateVoronoiMask": "CreateVoronoiMask",
} }