Initial version of voronoi mask

works but WIP
This commit is contained in:
kijai 2023-11-19 03:03:25 +02:00
parent b408a75579
commit 02b00c6480

View File

@ -4,6 +4,8 @@ import torch.nn.functional as F
from torchvision.transforms import Resize, CenterCrop, InterpolationMode
from torchvision.transforms import functional as TF
import scipy.ndimage
from scipy.spatial import Voronoi, voronoi_plot_2d
import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageFilter, Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
@ -2115,6 +2117,70 @@ class CreateShapeMask:
return (torch.cat(out, dim=0), 1.0 - torch.cat(out, dim=0),)
class CreateVoronoiMask:
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
FUNCTION = "createvoronoi"
CATEGORY = "KJNodes/masking/generate"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
"num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
"line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
"speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
},
}
def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
# Define the number of images in the batch
batch_size = frames
out = []
# Create start and end points for each point
start_points = np.random.rand(num_points, 2)
end_points = np.random.rand(num_points, 2)
for i in range(batch_size):
# Interpolate the points' positions based on the current frame
t = (i * speed) / (batch_size - 1) # normalize to [0, 1] over the frames
t = np.clip(t, 0, 1) # ensure t is in [0, 1]
points = (1 - t) * start_points + t * end_points # lerp
vor = Voronoi(points)
# Create a blank image with a white background
fig, ax = plt.subplots()
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax.set_xlim([0, 1]); ax.set_ylim([0, 1])
ax.axis('off')
ax.margins(0, 0)
fig.set_size_inches(frame_width/100, frame_height/100)
ax.fill_between([0, 1], [0, 1], color='white')
# Plot each Voronoi ridge
for simplex in vor.ridge_vertices:
simplex = np.asarray(simplex)
if np.all(simplex >= 0):
plt.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)
fig.canvas.draw()
img = np.array(fig.canvas.renderer._renderer)
plt.close(fig)
pil_img = Image.fromarray(img).convert("L")
mask = torch.tensor(np.array(pil_img)) / 255.0
out.append(mask)
return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
@ -2153,6 +2219,7 @@ NODE_CLASS_MAPPINGS = {
"OffsetMask": OffsetMask,
"WidgetToString": WidgetToString,
"CreateShapeMask": CreateShapeMask,
"CreateVoronoiMask": CreateVoronoiMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -2191,4 +2258,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"OffsetMask": "OffsetMask",
"WidgetToString": "WidgetToString",
"CreateShapeMask": "CreateShapeMask",
"CreateVoronoiMask": "CreateVoronoiMask",
}