diff --git a/__init__.py b/__init__.py index f8d46fe..09f0957 100644 --- a/__init__.py +++ b/__init__.py @@ -93,6 +93,7 @@ NODE_CLASS_MAPPINGS = { "MaskOrImageToWeight": MaskOrImageToWeight, "WeightScheduleConvert": WeightScheduleConvert, "FloatToMask": FloatToMask, + "FloatToSigmas": FloatToSigmas, #experimental "StabilityAPI_SD3": StabilityAPI_SD3, "SoundReactive": SoundReactive, @@ -189,6 +190,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MaskOrImageToWeight": "Mask Or Image To Weight", "WeightScheduleConvert": "Weight Schedule Convert", "FloatToMask": "Float To Mask", + "FloatToSigmas": "Float To Sigmas", "CustomSigmas": "Custom Sigmas", "ImagePass": "ImagePass", #curve nodes diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index fd25d4f..a723f06 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -34,7 +34,6 @@ class SplineEditor: "float_output_type": ( [ 'list', - 'list of lists', 'pandas series', 'tensor', ], @@ -76,8 +75,6 @@ output types: example compatible nodes: anything that takes masks - list of floats example compatible nodes: IPAdapter weights - - list of lists - example compatible nodes: unknown - pandas series example compatible nodes: anything that takes Fizz' nodes Batch Value Schedule @@ -92,14 +89,13 @@ output types: for coord in coordinates: coord['x'] = int(round(coord['x'])) coord['y'] = int(round(coord['y'])) + normalized_y_values = [ (1.0 - (point['y'] / 512) - 0.0) * (max_value - min_value) + min_value for point in coordinates ] if float_output_type == 'list': out_floats = normalized_y_values * repeat_output - elif float_output_type == 'list of lists': - out_floats = ([[value] for value in normalized_y_values] * repeat_output), elif float_output_type == 'pandas series': try: import pandas as pd @@ -207,7 +203,6 @@ class MaskOrImageToWeight: "output_type": ( [ 'list', - 'list of lists', 'pandas series', 'tensor', ], @@ -243,8 +238,6 @@ and returns that as the selected output type. # Convert mean_values to the specified output_type if output_type == 'list': return mean_values, - elif output_type == 'list of lists': - return [[value] for value in mean_values], elif output_type == 'pandas series': try: import pandas as pd @@ -267,7 +260,6 @@ class WeightScheduleConvert: [ 'match_input', 'list', - 'list of lists', 'pandas series', 'tensor', ], @@ -298,8 +290,6 @@ Converts different value lists/series to another type. return 'pandas series' elif isinstance(input_values, torch.Tensor): return 'tensor' - elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values): - return 'list of lists' else: raise ValueError("Unsupported input type") @@ -307,9 +297,7 @@ Converts different value lists/series to another type. import pandas as pd input_type = self.detect_input_type(input_values) - if input_type == 'list of lists': - float_values = [item for sublist in input_values for item in sublist] - elif input_type == 'pandas series': + if input_type == 'pandas series': float_values = input_values.tolist() elif input_type == 'tensor': float_values = input_values @@ -345,13 +333,13 @@ Converts different value lists/series to another type. if output_type == 'list': return float_values, - elif output_type == 'list of lists': - return [[value] for value in float_values], elif output_type == 'pandas series': return pd.Series(float_values), elif output_type == 'tensor': if input_type == 'pandas series': - return torch.tensor(input_values.values, dtype=torch.float32), + return torch.tensor(float_values.values, dtype=torch.float32), + else: + return torch.tensor(float_values, dtype=torch.float32), elif output_type == 'match_input': return float_values, else: @@ -408,7 +396,6 @@ class WeightScheduleExtend: [ 'match_input', 'list', - 'list of lists', 'pandas series', 'tensor', ], @@ -433,8 +420,6 @@ Extends, and converts if needed, different value lists/series return 'pandas series' elif isinstance(input_values, torch.Tensor): return 'tensor' - elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values): - return 'list of lists' else: raise ValueError("Unsupported input type") @@ -445,10 +430,7 @@ Extends, and converts if needed, different value lists/series # Convert input_values_2 to the same format as input_values_1 if they do not match if not input_type_1 == input_type_2: print("Converting input_values_2 to the same format as input_values_1") - if input_type_1 == 'list of lists': - # Assuming input_values_2 is a flat list, convert it to a list of lists - float_values_2 = [[item] for item in input_values_2] - elif input_type_1 == 'pandas series': + if input_type_1 == 'pandas series': # Convert input_values_2 to a pandas Series float_values_2 = pd.Series(input_values_2) elif input_type_1 == 'tensor': @@ -463,14 +445,33 @@ Extends, and converts if needed, different value lists/series if output_type == 'list': return float_values, - elif output_type == 'list of lists': - return [[value] for value in float_values], elif output_type == 'pandas series': return pd.Series(float_values), elif output_type == 'tensor': if input_type_1 == 'pandas series': - return torch.tensor(input_values_1.values, dtype=torch.float32), + return torch.tensor(float_values.values, dtype=torch.float32), + else: + return torch.tensor(float_values, dtype=torch.float32), elif output_type == 'match_input': return float_values, else: - raise ValueError(f"Unsupported output_type: {output_type}") \ No newline at end of file + raise ValueError(f"Unsupported output_type: {output_type}") + +class FloatToSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + { + "float_list": ("FLOAT", {"default": 0.0, "forceInput": True}), + } + } + RETURN_TYPES = ("SIGMAS",) + RETURN_NAMES = ("SIGMAS",) + CATEGORY = "KJNodes/noise" + FUNCTION = "customsigmas" + DESCRIPTION = """ +Creates a sigmas tensor from list of float values. + +""" + def customsigmas(self, float_list): + return torch.tensor(float_list, dtype=torch.float32), \ No newline at end of file diff --git a/web/js/spline_editor.js b/web/js/spline_editor.js index 3228c5a..6b0bf63 100644 --- a/web/js/spline_editor.js +++ b/web/js/spline_editor.js @@ -148,23 +148,27 @@ app.registerExtension({ this.menuItem2.textContent = "Display sample points"; styleMenuItem(this.menuItem2); - // Add hover effect to menu items - this.menuItem1.addEventListener('mouseover', function() { - this.style.backgroundColor = "gray"; + this.menuItem3 = document.createElement("a"); + this.menuItem3.href = "#"; + this.menuItem3.id = "menu-item-2"; + this.menuItem3.textContent = "Switch sampling method"; + styleMenuItem(this.menuItem3); + + const menuItems = [this.menuItem1, this.menuItem2, this.menuItem3]; + + menuItems.forEach(menuItem => { + menuItem.addEventListener('mouseover', function() { + this.style.backgroundColor = "gray"; + }); + menuItem.addEventListener('mouseout', function() { + this.style.backgroundColor = "#202020"; }); - this.menuItem1.addEventListener('mouseout', function() { - this.style.backgroundColor = "#202020"; }); - this.menuItem2.addEventListener('mouseover', function() { - this.style.backgroundColor = "gray"; + // Append menu items to the context menu + menuItems.forEach(menuItem => { + this.contextMenu.appendChild(menuItem); }); - this.menuItem2.addEventListener('mouseout', function() { - this.style.backgroundColor = "#202020"; - }); - - this.contextMenu.appendChild(this.menuItem1); - this.contextMenu.appendChild(this.menuItem2); document.body.appendChild( this.contextMenu); @@ -241,15 +245,26 @@ function createSplineEditor(context, reset=false) { context.menuItem2.addEventListener('click', function(e) { e.preventDefault(); drawSamplePoints = !drawSamplePoints; - updatePath(); }); + context.menuItem3.addEventListener('click', function(e) { + e.preventDefault(); + if (pointSamplingMethod == samplePointsTime) { + pointSamplingMethod = samplePointsPath + } + else { + pointSamplingMethod = samplePointsTime + } + updatePath(); +}); + var drawSamplePoints = false; + var pointSamplingMethod = samplePointsTime function updatePath() { points_to_sample = pointsWidget.value - let coords = samplePoints(pathElements[0], points_to_sample); + let coords = pointSamplingMethod(pathElements[0], points_to_sample); if (drawSamplePoints) { if (pointsLayer) { // Update the data of the existing points layer @@ -315,9 +330,11 @@ function createSplineEditor(context, reset=false) { } minValueWidget.callback = () => { + rangeMin = minValueWidget.value updatePath(); } maxValueWidget.callback = () => { + rangeMax = maxValueWidget.value updatePath(); } @@ -444,7 +461,6 @@ function createSplineEditor(context, reset=false) { .font(12 + "px sans-serif") .text(d => { - // Normalize y to range 0.0 to 1.0, considering the inverted y-axis let normalizedY = (1.0 - (d.y / h) - 0.0) * (rangeMax - rangeMin) + rangeMin; let normalizedX = (d.x / w); let frame = Math.round((d.x / w) * points_to_sample); @@ -461,13 +477,14 @@ function createSplineEditor(context, reset=false) { updatePath(); } -function samplePoints(svgPathElement, numSamples) { +function samplePointsPath(svgPathElement, numSamples) { var pathLength = svgPathElement.getTotalLength(); var points = []; for (var i = 0; i < numSamples; i++) { // Calculate the distance along the path for the current sample var distance = (pathLength / (numSamples - 1)) * i; + console.log(distance) // Get the point at the current distance var point = svgPathElement.getPointAtLength(distance); @@ -475,10 +492,57 @@ function samplePoints(svgPathElement, numSamples) { // Add the point to the array of points points.push({ x: point.x, y: point.y }); } - //console.log(points); + console.log(points); return points; } +function samplePointsTime(svgPathElement, numSamples) { + var svgWidth = 512; // Fixed width of the SVG element + var pathLength = svgPathElement.getTotalLength(); + var points = []; + + for (var i = 0; i < numSamples; i++) { + // Calculate the x-coordinate for the current sample based on the SVG's width + var x = (svgWidth / (numSamples - 1)) * i; + + // Find the point on the path that intersects the vertical line at the calculated x-coordinate + var point = findPointAtX(svgPathElement, x, pathLength); + + // Add the point to the array of points + points.push({ x: point.x, y: point.y }); + } + return points; +} + +function findPointAtX(svgPathElement, targetX, pathLength) { + let low = 0; + let high = pathLength; + let bestPoint = svgPathElement.getPointAtLength(0); + + while (low <= high) { + let mid = low + (high - low) / 2; + let point = svgPathElement.getPointAtLength(mid); + + if (Math.abs(point.x - targetX) < 1) { + return point; // The point is close enough to the target + } + + if (point.x < targetX) { + low = mid + 1; + } else { + high = mid - 1; + } + + // Keep track of the closest point found so far + if (Math.abs(point.x - targetX) < Math.abs(bestPoint.x - targetX)) { + bestPoint = point; + } + } + + // Return the closest point found + return bestPoint; +} + //from melmass export function hideWidgetForGood(node, widget, suffix = '') { widget.origType = widget.type