Repeat output for spline editor

This commit is contained in:
kijai 2024-04-26 01:18:36 +03:00
parent 6e0784801a
commit 8c5d9ae4ad
2 changed files with 11 additions and 14 deletions

View File

@ -4771,7 +4771,7 @@ class SplineEditor:
"default": 'cardinal' "default": 'cardinal'
}), }),
"tension": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), "tension": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"segmented": ("BOOLEAN", {"default": False}), "repeat_output": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
"float_output_type": ( "float_output_type": (
[ [
'list', 'list',
@ -4822,37 +4822,34 @@ output types:
example compatible nodes: unknown example compatible nodes: unknown
""" """
def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation, points_to_sample, points_store, tension, segmented): def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation, points_to_sample, points_store, tension, repeat_output):
coordinates = json.loads(coordinates) coordinates = json.loads(coordinates)
normalized_y_values = [ normalized_y_values = [
1.0 - (point['y'] / 512) 1.0 - (point['y'] / 512)
for point in coordinates for point in coordinates
] ]
if float_output_type == 'list': if float_output_type == 'list':
out_floats = normalized_y_values out_floats = normalized_y_values * repeat_output
elif float_output_type == 'list of lists': elif float_output_type == 'list of lists':
out_floats = [[value] for value in normalized_y_values], out_floats = ([[value] for value in normalized_y_values] * repeat_output),
elif float_output_type == 'pandas series': elif float_output_type == 'pandas series':
try: try:
import pandas as pd import pandas as pd
except: except:
raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type") raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
out_floats = pd.Series(normalized_y_values), out_floats = pd.Series(normalized_y_values * repeat_output),
elif float_output_type == 'tensor': elif float_output_type == 'tensor':
out_floats = torch.tensor(normalized_y_values, dtype=torch.float32) out_floats = torch.tensor(normalized_y_values * repeat_output, dtype=torch.float32)
# Create a color map for grayscale intensities # Create a color map for grayscale intensities
color_map = lambda y: torch.full((mask_height, mask_width, 3), y, dtype=torch.float32) color_map = lambda y: torch.full((mask_height, mask_width, 3), y, dtype=torch.float32)
# Create image tensors for each normalized y value # Create image tensors for each normalized y value
image_tensors = [color_map(y) for y in normalized_y_values] mask_tensors = [color_map(y) for y in normalized_y_values]
masks_out = torch.stack(mask_tensors)
# Batch the tensors masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
masks_out = torch.stack(image_tensors)
masks_out = masks_out.mean(dim=-1) masks_out = masks_out.mean(dim=-1)
print(masks_out.shape) return (masks_out, str(coordinates), out_floats,)
return (masks_out, coordinates, out_floats,)
class StabilityAPI_SD3: class StabilityAPI_SD3:

View File

@ -170,7 +170,7 @@ function createSplineEditor(context, reset=false) {
const pointsWidget = context.widgets.find(w => w.name === "points_to_sample"); const pointsWidget = context.widgets.find(w => w.name === "points_to_sample");
const pointsStoreWidget = context.widgets.find(w => w.name === "points_store"); const pointsStoreWidget = context.widgets.find(w => w.name === "points_store");
const tensionWidget = context.widgets.find(w => w.name === "tension"); const tensionWidget = context.widgets.find(w => w.name === "tension");
const segmentedWidget = context.widgets.find(w => w.name === "segmented"); //const segmentedWidget = context.widgets.find(w => w.name === "segmented");
var interpolation = interpolationWidget.value var interpolation = interpolationWidget.value
var tension = tensionWidget.value var tension = tensionWidget.value