Add WeightScheduleConvert

This commit is contained in:
kijai 2024-04-22 09:50:15 +03:00
parent e9b8af8fd0
commit 9b25146064
2 changed files with 77 additions and 7 deletions

View File

@ -386,8 +386,8 @@ and interpolating from that to fully black at the 16th frame.
"points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
"invert": ("BOOLEAN", {"default": False}),
"frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}),
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"width": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
"interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
},
}
@ -4906,8 +4906,7 @@ class MaskOrImageToWeight:
FUNCTION = "execute"
CATEGORY = "KJNodes"
DESCRIPTION = """
Gets the mean value of mask or image
and returns it as a float value.
Converts different value lists/series to another type.
"""
def execute(self, output_type, images=None, masks=None):
@ -4915,7 +4914,6 @@ and returns it as a float value.
if masks is not None and images is None:
for mask in masks:
mean_values.append(mask.mean().item())
print(mean_values)
elif masks is None and images is not None:
for image in images:
mean_values.append(image.mean().item())
@ -4934,9 +4932,79 @@ and returns it as a float value.
raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
return pd.Series(mean_values),
elif output_type == 'tensor':
return torch.tensor(mean_values, dtype=torch.float32)
return torch.tensor(mean_values, dtype=torch.float32),
else:
raise ValueError(f"Unsupported output_type: {output_type}")
class WeightScheduleConvert:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values": ("FLOAT", {"default": 0.0, "forceInput": True}),
"output_type": (
[
'list',
'list of lists',
'pandas series',
'tensor',
],
{
"default": 'list'
}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "KJNodes"
DESCRIPTION = """
Gets the mean value of mask or image
and returns it as a float value.
"""
def detect_input_type(self, input_values):
import pandas as pd
if isinstance(input_values, list):
return 'list'
elif isinstance(input_values, pd.Series):
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values):
return 'list of lists'
else:
raise ValueError("Unsupported input type")
def execute(self, input_values, output_type):
import pandas as pd
# Detect the input type
input_type = self.detect_input_type(input_values)
# Convert input_values to a list of floats
if input_type == 'list of lists':
float_values = [item for sublist in input_values for item in sublist]
elif input_type == 'pandas series':
float_values = input_values.tolist()
elif input_type == 'tensor':
float_values = input_values
else:
float_values = input_values
if output_type == 'list':
return float_values,
elif output_type == 'list of lists':
return [[value] for value in float_values],
elif output_type == 'pandas series':
return pd.Series(float_values),
elif output_type == 'tensor':
if input_type == 'pandas series':
return torch.tensor(input_values.values, dtype=torch.float32),
else:
raise ValueError(f"Unsupported output_type: {output_type}")
class FloatToMask:
@classmethod
@ -5060,6 +5128,7 @@ NODE_CLASS_MAPPINGS = {
"ImageAndMaskPreview": ImageAndMaskPreview,
"StabilityAPI_SD3": StabilityAPI_SD3,
"MaskOrImageToWeight": MaskOrImageToWeight,
"WeightScheduleConvert": WeightScheduleConvert,
"FloatToMask": FloatToMask,
"CustomSigmas": CustomSigmas
}
@ -5145,6 +5214,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageAndMaskPreview": "Image & Mask Preview",
"StabilityAPI_SD3": "Stability API SD3",
"MaskOrImageToWeight": "Mask Or Image To Weight",
"WeightScheduleConvert": "Weight Schedule Convert",
"FloatToMask": "Float To Mask",
"CustomSigmas": "Custom Sigmas",
}

View File

@ -123,7 +123,7 @@ app.registerExtension({
createSplineEditor(this, true)
}
});
this.setSize([550, 800])
this.setSize([550, 850])
this.splineEditor.parentEl = document.createElement("div");
this.splineEditor.parentEl.className = "spline-editor";
this.splineEditor.parentEl.id = `spline-editor-${this.uuid}`