Spline editor updates

This commit is contained in:
Kijai 2024-04-26 18:24:51 +03:00
parent 3275b332f3
commit c7b3a85615
2 changed files with 242 additions and 13 deletions

131
nodes.py
View File

@ -2851,7 +2851,7 @@ Grow value is the amount to grow the shape on each frame, creating animated mask
image = pil2tensor(image)
mask = image[:, :, :, 0]
out.append(mask)
outstack = torch.cat(out, dim=0)
outstack = torch.cat(out, dim=0)
return (outstack, 1.0 - outstack,)
class CreateVoronoiMask:
@ -4826,6 +4826,9 @@ output types:
def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation, points_to_sample, points_store, tension, repeat_output):
coordinates = json.loads(coordinates)
for coord in coordinates:
coord['x'] = int(round(coord['x']))
coord['y'] = int(round(coord['y']))
normalized_y_values = [
1.0 - (point['y'] / 512)
for point in coordinates
@ -4851,7 +4854,87 @@ output types:
masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
masks_out = masks_out.mean(dim=-1)
return (masks_out, str(coordinates), out_floats,)
class CreateShapeMaskOnPath:
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
FUNCTION = "createshapemask"
CATEGORY = "KJNodes/masking/generate"
DESCRIPTION = """
Creates a mask or batch of masks with the specified shape.
Locations are center locations.
Grow value is the amount to grow the shape on each frame, creating animated masks.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"shape": (
[ 'circle',
'square',
'triangle',
],
{
"default": 'circle'
}),
"coordinates": ("STRING", {"forceInput": True}),
"grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
},
}
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, grow, shape):
# Define the number of images in the batch
coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates)
for coord in coordinates:
print(coord)
batch_size = len(coordinates)
print(batch_size)
out = []
color = "white"
for i, coord in enumerate(coordinates):
image = Image.new("RGB", (frame_width, frame_height), "black")
draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0
current_width = max(0, shape_width + i*grow)
current_height = max(0, shape_height + i*grow)
location_x = coord['x']
location_y = coord['y']
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=color)
elif shape == 'square':
draw.rectangle(two_points, fill=color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
top_point = (location_x, location_y - current_height // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=color)
image = pil2tensor(image)
mask = image[:, :, :, 0]
out.append(mask)
outstack = torch.cat(out, dim=0)
return (outstack, 1.0 - outstack,)
class StabilityAPI_SD3:
@classmethod
@ -5072,6 +5155,7 @@ class WeightScheduleConvert:
"input_values": ("FLOAT", {"default": 0.0, "forceInput": True}),
"output_type": (
[
'match_input',
'list',
'list of lists',
'pandas series',
@ -5080,7 +5164,14 @@ class WeightScheduleConvert:
{
"default": 'list'
}),
"invert": ("BOOLEAN", {"default": False}),
"repeat": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
},
"optional": {
"remap_to_frames": ("INT", {"default": 0}),
"interpolation_curve": ("FLOAT", {"forceInput": True}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
@ -5102,9 +5193,8 @@ Converts different value lists/series to another type.
else:
raise ValueError("Unsupported input type")
def execute(self, input_values, output_type):
def execute(self, input_values, output_type, invert, repeat, remap_to_frames=0, interpolation_curve=None):
import pandas as pd
# Detect the input type
input_type = self.detect_input_type(input_values)
# Convert input_values to a list of floats
@ -5117,6 +5207,35 @@ Converts different value lists/series to another type.
else:
float_values = input_values
if invert:
float_values = [1 - value for value in float_values]
if interpolation_curve is not None:
interpolated_pattern = []
orig_float_values = float_values
for value in interpolation_curve:
print(value)
min_val = min(orig_float_values)
max_val = max(orig_float_values)
# Normalize the values to [0, 1]
normalized_values = [(value - min_val) / (max_val - min_val) for value in orig_float_values]
# Interpolate the normalized values to the new frame count
remapped_float_values = np.interp(np.linspace(0, 1, int(remap_to_frames * value)), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
interpolated_pattern.append(remapped_float_values)
print(interpolated_pattern)
float_values = interpolated_pattern
else:
# Remap float_values to match target_frame_amount
if remap_to_frames > 0 and remap_to_frames != len(float_values):
min_val = min(float_values)
max_val = max(float_values)
# Normalize the values to [0, 1]
normalized_values = [(value - min_val) / (max_val - min_val) for value in float_values]
# Interpolate the normalized values to the new frame count
float_values = np.interp(np.linspace(0, 1, remap_to_frames), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
float_values = float_values * repeat
if output_type == 'list':
return float_values,
elif output_type == 'list of lists':
@ -5126,6 +5245,8 @@ Converts different value lists/series to another type.
elif output_type == 'tensor':
if input_type == 'pandas series':
return torch.tensor(input_values.values, dtype=torch.float32),
elif output_type == 'match_input':
return float_values,
else:
raise ValueError(f"Unsupported output_type: {output_type}")
@ -5257,7 +5378,8 @@ NODE_CLASS_MAPPINGS = {
"WeightScheduleConvert": WeightScheduleConvert,
"FloatToMask": FloatToMask,
"CustomSigmas": CustomSigmas,
"ImagePass": ImagePass
"ImagePass": ImagePass,
"CreateShapeMaskOnPath": CreateShapeMaskOnPath
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -5347,4 +5469,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"FloatToMask": "Float To Mask",
"CustomSigmas": "Custom Sigmas",
"ImagePass": "ImagePass",
"CreateShapeMaskOnPath": "CreateShapeMaskOnPath",
}

View File

@ -113,6 +113,58 @@ app.registerExtension({
serialize: false,
hideOnZoom: false,
});
// context menu
this.contextMenu = document.createElement("div");
this.contextMenu.id = "context-menu";
this.contextMenu.style.display = "none";
this.contextMenu.style.position = "absolute";
this.contextMenu.style.backgroundColor = "#202020";
this.contextMenu.style.minWidth = "100px";
this.contextMenu.style.boxShadow = "0px 8px 16px 0px rgba(0,0,0,0.2)";
this.contextMenu.style.zIndex = "100";
this.contextMenu.style.padding = "5px";
function styleMenuItem(menuItem) {
menuItem.style.display = "block";
menuItem.style.padding = "5px";
menuItem.style.color = "#FFF";
menuItem.style.fontFamily = "Arial, sans-serif";
menuItem.style.fontSize = "16px";
menuItem.style.textDecoration = "none";
menuItem.style.marginBottom = "5px";
}
this.menuItem1 = document.createElement("a");
this.menuItem1.href = "#";
this.menuItem1.id = "menu-item-1";
this.menuItem1.textContent = "Toggle handles";
styleMenuItem(this.menuItem1);
this.menuItem2 = document.createElement("a");
this.menuItem2.href = "#";
this.menuItem2.id = "menu-item-2";
this.menuItem2.textContent = "Placeholder";
styleMenuItem(this.menuItem2);
// Add hover effect to menu items
this.menuItem1.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
});
this.menuItem1.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
this.menuItem2.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
});
this.menuItem2.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
this.contextMenu.appendChild(this.menuItem1);
this.contextMenu.appendChild(this.menuItem2);
document.body.appendChild( this.contextMenu);
this.addWidget("button", "New spline", null, () => {
if (!this.properties || !("points" in this.properties)) {
@ -129,18 +181,20 @@ app.registerExtension({
this.splineEditor.parentEl.id = `spline-editor-${this.uuid}`
element.appendChild(this.splineEditor.parentEl);
//disable context menu on right click
document.addEventListener('contextmenu', function(e) {
if (e.button === 2) { // Right mouse button
e.preventDefault();
e.stopPropagation();
}
})
chainCallback(this, "onGraphConfigured", function() {
console.log('onGraphConfigured');
createSplineEditor(this)
this.setSize([550, 840])
});
//disable context menu on right click
// document.addEventListener('contextmenu', function(e) {
// if (e.button === 2) { // Right mouse button
// e.preventDefault();
// e.stopPropagation();
// }
// })
}); // onAfterGraphConfigured
}//node created
} //before register
@ -150,6 +204,51 @@ app.registerExtension({
function createSplineEditor(context, reset=false) {
console.log("creatingSplineEditor")
document.addEventListener('contextmenu', function(e) {
e.preventDefault();
});
document.addEventListener('click', function(e) {
if (!context.contextMenu.contains(e.target)) {
context.contextMenu.style.display = 'none';
}
});
context.menuItem1.addEventListener('click', function(e) {
e.preventDefault();
if (!drawHandles) {
drawHandles = true
vis.add(pv.Line)
.data(() => points.map((point, index) => ({
start: point,
end: [index]
})))
.left(d => d.start.x)
.top(d => d.start.y)
.interpolate("linear")
.tension(0) // Straight lines
.strokeStyle("#ff7f0e") // Same color as control points
.lineWidth(1)
.visible(() => drawHandles);
vis.render();
} else {
drawHandles = false
vis.render();
}
context.contextMenu.style.display = 'none';
});
context.menuItem2.addEventListener('click', function(e) {
e.preventDefault();
// Add functionality for menu item 2
console.log('Option 2 clicked');
});
function updatePath() {
points_to_sample = pointsWidget.value
let coords = samplePoints(pathElements[0], points_to_sample);
@ -188,6 +287,7 @@ function createSplineEditor(context, reset=false) {
}
// Initialize or reset points array
var drawHandles = false
var w = 512
var h = 512
var i = 3
@ -226,6 +326,11 @@ function createSplineEditor(context, reset=false) {
i = points.push(this.mouse()) - 1;
return this;
}
else if (pv.event.button === 2) {
context.contextMenu.style.display = 'block';
context.contextMenu.style.left = `${pv.event.clientX}px`;
context.contextMenu.style.top = `${pv.event.clientY}px`;
}
})
.event("mouseup", function() {
if (this.pathElements !== null) {
@ -237,7 +342,7 @@ function createSplineEditor(context, reset=false) {
.data(pv.range(0, 8, .5))
.bottom(d => d * 64 + 0)
.strokeStyle("gray")
.lineWidth(1)
.lineWidth(2)
vis.add(pv.Line)
.data(() => points)
@ -291,7 +396,7 @@ function createSplineEditor(context, reset=false) {
.left(d => d.x < w / 2 ? d.x + 80 : d.x - 70) // Shift label to right if on left half, otherwise shift to left
.top(d => d.y < h / 2 ? d.y + 20 : d.y - 20) // Shift label down if on top half, otherwise shift up
.font(12 + "px Consolas")
.font(12 + "px sans-serif")
.text(d => {
// Normalize y to range 0.0 to 1.0, considering the inverted y-axis
var normalizedY = 1.0 - (d.y / h);
@ -301,6 +406,7 @@ function createSplineEditor(context, reset=false) {
})
.textStyle("orange")
vis.render();
var svgElement = vis.canvas();
svgElement.style['zIndex'] = "2"