From b90c8d830f34d3017373275017167eef7db3ac5f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 27 Apr 2024 11:50:24 +0300 Subject: [PATCH] Add WeightScheduleExtend -node --- __init__.py | 6 ++-- curve_nodes.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/__init__.py b/__init__.py index f2a02ae..586f2c5 100644 --- a/__init__.py +++ b/__init__.py @@ -88,7 +88,8 @@ NODE_CLASS_MAPPINGS = { "CustomSigmas": CustomSigmas, "ImagePass": ImagePass, "SplineEditor": SplineEditor, - "CreateShapeMaskOnPath": CreateShapeMaskOnPath + "CreateShapeMaskOnPath": CreateShapeMaskOnPath, + "WeightScheduleExtend": WeightScheduleExtend } NODE_DISPLAY_NAME_MAPPINGS = { @@ -179,7 +180,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CustomSigmas": "Custom Sigmas", "ImagePass": "ImagePass", "SplineEditor": "Spline Editor", - "CreateShapeMaskOnPath": "Create Shape Mask On Path" + "CreateShapeMaskOnPath": "Create Shape Mask On Path", + "WeightScheduleExtend": "Weight Schedule Extend" } __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] diff --git a/curve_nodes.py b/curve_nodes.py index 0c45454..0484a70 100644 --- a/curve_nodes.py +++ b/curve_nodes.py @@ -396,4 +396,82 @@ Each mask is generated with the specified width and height. masks.append(mask) masks_out = torch.stack(masks, dim=0) - return(masks_out,) \ No newline at end of file + return(masks_out,) +class WeightScheduleExtend: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_values_1": ("FLOAT", {"default": 0.0, "forceInput": True}), + "input_values_2": ("FLOAT", {"default": 0.0, "forceInput": True}), + "output_type": ( + [ + 'match_input', + 'list', + 'list of lists', + 'pandas series', + 'tensor', + ], + { + "default": 'match_input' + }), + }, + + } + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + CATEGORY = "KJNodes" + DESCRIPTION = """ +Converts different value lists/series to another type. +""" + + def detect_input_type(self, input_values): + import pandas as pd + if isinstance(input_values, list): + return 'list' + elif isinstance(input_values, pd.Series): + return 'pandas series' + elif isinstance(input_values, torch.Tensor): + return 'tensor' + elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values): + return 'list of lists' + else: + raise ValueError("Unsupported input type") + + def execute(self, input_values_1, input_values_2, output_type): + import pandas as pd + input_type_1 = self.detect_input_type(input_values_1) + input_type_2 = self.detect_input_type(input_values_2) + # Convert input_values_2 to the same format as input_values_1 if they do not match + if not input_type_1 == input_type_2: + print("Converting input_values_2 to the same format as input_values_1") + if input_type_1 == 'list of lists': + # Assuming input_values_2 is a flat list, convert it to a list of lists + float_values_2 = [[item] for item in input_values_2] + elif input_type_1 == 'pandas series': + # Convert input_values_2 to a pandas Series + float_values_2 = pd.Series(input_values_2) + elif input_type_1 == 'tensor': + # Convert input_values_2 to a tensor + float_values_2 = torch.tensor(input_values_2, dtype=torch.float32) + else: + print("Input types match, no conversion needed") + # If the types match, no conversion is needed + float_values_2 = input_values_2 + + float_values = input_values_1 + float_values_2 + + if output_type == 'list': + return float_values, + elif output_type == 'list of lists': + return [[value] for value in float_values], + elif output_type == 'pandas series': + return pd.Series(float_values), + elif output_type == 'tensor': + if input_type_1 == 'pandas series': + return torch.tensor(input_values_1.values, dtype=torch.float32), + elif output_type == 'match_input': + return float_values, + else: + raise ValueError(f"Unsupported output_type: {output_type}") \ No newline at end of file