Add WeightScheduleExtend -node

This commit is contained in:
kijai 2024-04-27 11:50:24 +03:00
parent ae0eb9be67
commit b90c8d830f
2 changed files with 83 additions and 3 deletions

View File

@ -88,7 +88,8 @@ NODE_CLASS_MAPPINGS = {
"CustomSigmas": CustomSigmas,
"ImagePass": ImagePass,
"SplineEditor": SplineEditor,
"CreateShapeMaskOnPath": CreateShapeMaskOnPath
"CreateShapeMaskOnPath": CreateShapeMaskOnPath,
"WeightScheduleExtend": WeightScheduleExtend
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -179,7 +180,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CustomSigmas": "Custom Sigmas",
"ImagePass": "ImagePass",
"SplineEditor": "Spline Editor",
"CreateShapeMaskOnPath": "Create Shape Mask On Path"
"CreateShapeMaskOnPath": "Create Shape Mask On Path",
"WeightScheduleExtend": "Weight Schedule Extend"
}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]

View File

@ -397,3 +397,81 @@ Each mask is generated with the specified width and height.
masks_out = torch.stack(masks, dim=0)
return(masks_out,)
class WeightScheduleExtend:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values_1": ("FLOAT", {"default": 0.0, "forceInput": True}),
"input_values_2": ("FLOAT", {"default": 0.0, "forceInput": True}),
"output_type": (
[
'match_input',
'list',
'list of lists',
'pandas series',
'tensor',
],
{
"default": 'match_input'
}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "KJNodes"
DESCRIPTION = """
Converts different value lists/series to another type.
"""
def detect_input_type(self, input_values):
import pandas as pd
if isinstance(input_values, list):
return 'list'
elif isinstance(input_values, pd.Series):
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values):
return 'list of lists'
else:
raise ValueError("Unsupported input type")
def execute(self, input_values_1, input_values_2, output_type):
import pandas as pd
input_type_1 = self.detect_input_type(input_values_1)
input_type_2 = self.detect_input_type(input_values_2)
# Convert input_values_2 to the same format as input_values_1 if they do not match
if not input_type_1 == input_type_2:
print("Converting input_values_2 to the same format as input_values_1")
if input_type_1 == 'list of lists':
# Assuming input_values_2 is a flat list, convert it to a list of lists
float_values_2 = [[item] for item in input_values_2]
elif input_type_1 == 'pandas series':
# Convert input_values_2 to a pandas Series
float_values_2 = pd.Series(input_values_2)
elif input_type_1 == 'tensor':
# Convert input_values_2 to a tensor
float_values_2 = torch.tensor(input_values_2, dtype=torch.float32)
else:
print("Input types match, no conversion needed")
# If the types match, no conversion is needed
float_values_2 = input_values_2
float_values = input_values_1 + float_values_2
if output_type == 'list':
return float_values,
elif output_type == 'list of lists':
return [[value] for value in float_values],
elif output_type == 'pandas series':
return pd.Series(float_values),
elif output_type == 'tensor':
if input_type_1 == 'pandas series':
return torch.tensor(input_values_1.values, dtype=torch.float32),
elif output_type == 'match_input':
return float_values,
else:
raise ValueError(f"Unsupported output_type: {output_type}")