Repeat output for spline editor

This commit is contained in:
kijai 2024-04-26 01:18:36 +03:00
parent 6e0784801a
commit 8c5d9ae4ad
2 changed files with 11 additions and 14 deletions

View File

@ -4771,7 +4771,7 @@ class SplineEditor:
"default": 'cardinal'
}),
"tension": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"segmented": ("BOOLEAN", {"default": False}),
"repeat_output": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
"float_output_type": (
[
'list',
@ -4822,37 +4822,34 @@ output types:
example compatible nodes: unknown
"""
def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation, points_to_sample, points_store, tension, segmented):
def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation, points_to_sample, points_store, tension, repeat_output):
coordinates = json.loads(coordinates)
normalized_y_values = [
1.0 - (point['y'] / 512)
for point in coordinates
]
if float_output_type == 'list':
out_floats = normalized_y_values
out_floats = normalized_y_values * repeat_output
elif float_output_type == 'list of lists':
out_floats = [[value] for value in normalized_y_values],
out_floats = ([[value] for value in normalized_y_values] * repeat_output),
elif float_output_type == 'pandas series':
try:
import pandas as pd
except:
raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
out_floats = pd.Series(normalized_y_values),
out_floats = pd.Series(normalized_y_values * repeat_output),
elif float_output_type == 'tensor':
out_floats = torch.tensor(normalized_y_values, dtype=torch.float32)
out_floats = torch.tensor(normalized_y_values * repeat_output, dtype=torch.float32)
# Create a color map for grayscale intensities
color_map = lambda y: torch.full((mask_height, mask_width, 3), y, dtype=torch.float32)
# Create image tensors for each normalized y value
image_tensors = [color_map(y) for y in normalized_y_values]
# Batch the tensors
masks_out = torch.stack(image_tensors)
mask_tensors = [color_map(y) for y in normalized_y_values]
masks_out = torch.stack(mask_tensors)
masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
masks_out = masks_out.mean(dim=-1)
print(masks_out.shape)
return (masks_out, coordinates, out_floats,)
return (masks_out, str(coordinates), out_floats,)
class StabilityAPI_SD3:

View File

@ -170,7 +170,7 @@ function createSplineEditor(context, reset=false) {
const pointsWidget = context.widgets.find(w => w.name === "points_to_sample");
const pointsStoreWidget = context.widgets.find(w => w.name === "points_store");
const tensionWidget = context.widgets.find(w => w.name === "tension");
const segmentedWidget = context.widgets.find(w => w.name === "segmented");
//const segmentedWidget = context.widgets.find(w => w.name === "segmented");
var interpolation = interpolationWidget.value
var tension = tensionWidget.value