Proper point sampling for the spline editor

This commit is contained in:
kijai 2024-04-28 21:27:17 +03:00
parent 0843356c78
commit 1bb4b9bd26
3 changed files with 113 additions and 46 deletions

View File

@ -93,6 +93,7 @@ NODE_CLASS_MAPPINGS = {
"MaskOrImageToWeight": MaskOrImageToWeight,
"WeightScheduleConvert": WeightScheduleConvert,
"FloatToMask": FloatToMask,
"FloatToSigmas": FloatToSigmas,
#experimental
"StabilityAPI_SD3": StabilityAPI_SD3,
"SoundReactive": SoundReactive,
@ -189,6 +190,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MaskOrImageToWeight": "Mask Or Image To Weight",
"WeightScheduleConvert": "Weight Schedule Convert",
"FloatToMask": "Float To Mask",
"FloatToSigmas": "Float To Sigmas",
"CustomSigmas": "Custom Sigmas",
"ImagePass": "ImagePass",
#curve nodes

View File

@ -34,7 +34,6 @@ class SplineEditor:
"float_output_type": (
[
'list',
'list of lists',
'pandas series',
'tensor',
],
@ -76,8 +75,6 @@ output types:
example compatible nodes: anything that takes masks
- list of floats
example compatible nodes: IPAdapter weights
- list of lists
example compatible nodes: unknown
- pandas series
example compatible nodes: anything that takes Fizz'
nodes Batch Value Schedule
@ -92,14 +89,13 @@ output types:
for coord in coordinates:
coord['x'] = int(round(coord['x']))
coord['y'] = int(round(coord['y']))
normalized_y_values = [
(1.0 - (point['y'] / 512) - 0.0) * (max_value - min_value) + min_value
for point in coordinates
]
if float_output_type == 'list':
out_floats = normalized_y_values * repeat_output
elif float_output_type == 'list of lists':
out_floats = ([[value] for value in normalized_y_values] * repeat_output),
elif float_output_type == 'pandas series':
try:
import pandas as pd
@ -207,7 +203,6 @@ class MaskOrImageToWeight:
"output_type": (
[
'list',
'list of lists',
'pandas series',
'tensor',
],
@ -243,8 +238,6 @@ and returns that as the selected output type.
# Convert mean_values to the specified output_type
if output_type == 'list':
return mean_values,
elif output_type == 'list of lists':
return [[value] for value in mean_values],
elif output_type == 'pandas series':
try:
import pandas as pd
@ -267,7 +260,6 @@ class WeightScheduleConvert:
[
'match_input',
'list',
'list of lists',
'pandas series',
'tensor',
],
@ -298,8 +290,6 @@ Converts different value lists/series to another type.
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values):
return 'list of lists'
else:
raise ValueError("Unsupported input type")
@ -307,9 +297,7 @@ Converts different value lists/series to another type.
import pandas as pd
input_type = self.detect_input_type(input_values)
if input_type == 'list of lists':
float_values = [item for sublist in input_values for item in sublist]
elif input_type == 'pandas series':
if input_type == 'pandas series':
float_values = input_values.tolist()
elif input_type == 'tensor':
float_values = input_values
@ -345,13 +333,13 @@ Converts different value lists/series to another type.
if output_type == 'list':
return float_values,
elif output_type == 'list of lists':
return [[value] for value in float_values],
elif output_type == 'pandas series':
return pd.Series(float_values),
elif output_type == 'tensor':
if input_type == 'pandas series':
return torch.tensor(input_values.values, dtype=torch.float32),
return torch.tensor(float_values.values, dtype=torch.float32),
else:
return torch.tensor(float_values, dtype=torch.float32),
elif output_type == 'match_input':
return float_values,
else:
@ -408,7 +396,6 @@ class WeightScheduleExtend:
[
'match_input',
'list',
'list of lists',
'pandas series',
'tensor',
],
@ -433,8 +420,6 @@ Extends, and converts if needed, different value lists/series
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values):
return 'list of lists'
else:
raise ValueError("Unsupported input type")
@ -445,10 +430,7 @@ Extends, and converts if needed, different value lists/series
# Convert input_values_2 to the same format as input_values_1 if they do not match
if not input_type_1 == input_type_2:
print("Converting input_values_2 to the same format as input_values_1")
if input_type_1 == 'list of lists':
# Assuming input_values_2 is a flat list, convert it to a list of lists
float_values_2 = [[item] for item in input_values_2]
elif input_type_1 == 'pandas series':
if input_type_1 == 'pandas series':
# Convert input_values_2 to a pandas Series
float_values_2 = pd.Series(input_values_2)
elif input_type_1 == 'tensor':
@ -463,14 +445,33 @@ Extends, and converts if needed, different value lists/series
if output_type == 'list':
return float_values,
elif output_type == 'list of lists':
return [[value] for value in float_values],
elif output_type == 'pandas series':
return pd.Series(float_values),
elif output_type == 'tensor':
if input_type_1 == 'pandas series':
return torch.tensor(input_values_1.values, dtype=torch.float32),
return torch.tensor(float_values.values, dtype=torch.float32),
else:
return torch.tensor(float_values, dtype=torch.float32),
elif output_type == 'match_input':
return float_values,
else:
raise ValueError(f"Unsupported output_type: {output_type}")
raise ValueError(f"Unsupported output_type: {output_type}")
class FloatToSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"float_list": ("FLOAT", {"default": 0.0, "forceInput": True}),
}
}
RETURN_TYPES = ("SIGMAS",)
RETURN_NAMES = ("SIGMAS",)
CATEGORY = "KJNodes/noise"
FUNCTION = "customsigmas"
DESCRIPTION = """
Creates a sigmas tensor from list of float values.
"""
def customsigmas(self, float_list):
return torch.tensor(float_list, dtype=torch.float32),

View File

@ -148,23 +148,27 @@ app.registerExtension({
this.menuItem2.textContent = "Display sample points";
styleMenuItem(this.menuItem2);
// Add hover effect to menu items
this.menuItem1.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
this.menuItem3 = document.createElement("a");
this.menuItem3.href = "#";
this.menuItem3.id = "menu-item-2";
this.menuItem3.textContent = "Switch sampling method";
styleMenuItem(this.menuItem3);
const menuItems = [this.menuItem1, this.menuItem2, this.menuItem3];
menuItems.forEach(menuItem => {
menuItem.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
});
menuItem.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
this.menuItem1.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
this.menuItem2.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
// Append menu items to the context menu
menuItems.forEach(menuItem => {
this.contextMenu.appendChild(menuItem);
});
this.menuItem2.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
this.contextMenu.appendChild(this.menuItem1);
this.contextMenu.appendChild(this.menuItem2);
document.body.appendChild( this.contextMenu);
@ -241,15 +245,26 @@ function createSplineEditor(context, reset=false) {
context.menuItem2.addEventListener('click', function(e) {
e.preventDefault();
drawSamplePoints = !drawSamplePoints;
updatePath();
});
context.menuItem3.addEventListener('click', function(e) {
e.preventDefault();
if (pointSamplingMethod == samplePointsTime) {
pointSamplingMethod = samplePointsPath
}
else {
pointSamplingMethod = samplePointsTime
}
updatePath();
});
var drawSamplePoints = false;
var pointSamplingMethod = samplePointsTime
function updatePath() {
points_to_sample = pointsWidget.value
let coords = samplePoints(pathElements[0], points_to_sample);
let coords = pointSamplingMethod(pathElements[0], points_to_sample);
if (drawSamplePoints) {
if (pointsLayer) {
// Update the data of the existing points layer
@ -315,9 +330,11 @@ function createSplineEditor(context, reset=false) {
}
minValueWidget.callback = () => {
rangeMin = minValueWidget.value
updatePath();
}
maxValueWidget.callback = () => {
rangeMax = maxValueWidget.value
updatePath();
}
@ -444,7 +461,6 @@ function createSplineEditor(context, reset=false) {
.font(12 + "px sans-serif")
.text(d => {
// Normalize y to range 0.0 to 1.0, considering the inverted y-axis
let normalizedY = (1.0 - (d.y / h) - 0.0) * (rangeMax - rangeMin) + rangeMin;
let normalizedX = (d.x / w);
let frame = Math.round((d.x / w) * points_to_sample);
@ -461,13 +477,14 @@ function createSplineEditor(context, reset=false) {
updatePath();
}
function samplePoints(svgPathElement, numSamples) {
function samplePointsPath(svgPathElement, numSamples) {
var pathLength = svgPathElement.getTotalLength();
var points = [];
for (var i = 0; i < numSamples; i++) {
// Calculate the distance along the path for the current sample
var distance = (pathLength / (numSamples - 1)) * i;
console.log(distance)
// Get the point at the current distance
var point = svgPathElement.getPointAtLength(distance);
@ -475,10 +492,57 @@ function samplePoints(svgPathElement, numSamples) {
// Add the point to the array of points
points.push({ x: point.x, y: point.y });
}
//console.log(points);
console.log(points);
return points;
}
function samplePointsTime(svgPathElement, numSamples) {
var svgWidth = 512; // Fixed width of the SVG element
var pathLength = svgPathElement.getTotalLength();
var points = [];
for (var i = 0; i < numSamples; i++) {
// Calculate the x-coordinate for the current sample based on the SVG's width
var x = (svgWidth / (numSamples - 1)) * i;
// Find the point on the path that intersects the vertical line at the calculated x-coordinate
var point = findPointAtX(svgPathElement, x, pathLength);
// Add the point to the array of points
points.push({ x: point.x, y: point.y });
}
return points;
}
function findPointAtX(svgPathElement, targetX, pathLength) {
let low = 0;
let high = pathLength;
let bestPoint = svgPathElement.getPointAtLength(0);
while (low <= high) {
let mid = low + (high - low) / 2;
let point = svgPathElement.getPointAtLength(mid);
if (Math.abs(point.x - targetX) < 1) {
return point; // The point is close enough to the target
}
if (point.x < targetX) {
low = mid + 1;
} else {
high = mid - 1;
}
// Keep track of the closest point found so far
if (Math.abs(point.x - targetX) < Math.abs(bestPoint.x - targetX)) {
bestPoint = point;
}
}
// Return the closest point found
return bestPoint;
}
//from melmass
export function hideWidgetForGood(node, widget, suffix = '') {
widget.origType = widget.type