mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add WeightScheduleExtend -node
This commit is contained in:
parent
ae0eb9be67
commit
b90c8d830f
@ -88,7 +88,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CustomSigmas": CustomSigmas,
|
||||
"ImagePass": ImagePass,
|
||||
"SplineEditor": SplineEditor,
|
||||
"CreateShapeMaskOnPath": CreateShapeMaskOnPath
|
||||
"CreateShapeMaskOnPath": CreateShapeMaskOnPath,
|
||||
"WeightScheduleExtend": WeightScheduleExtend
|
||||
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -179,7 +180,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CustomSigmas": "Custom Sigmas",
|
||||
"ImagePass": "ImagePass",
|
||||
"SplineEditor": "Spline Editor",
|
||||
"CreateShapeMaskOnPath": "Create Shape Mask On Path"
|
||||
"CreateShapeMaskOnPath": "Create Shape Mask On Path",
|
||||
"WeightScheduleExtend": "Weight Schedule Extend"
|
||||
}
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
||||
|
||||
|
||||
@ -396,4 +396,82 @@ Each mask is generated with the specified width and height.
|
||||
masks.append(mask)
|
||||
masks_out = torch.stack(masks, dim=0)
|
||||
|
||||
return(masks_out,)
|
||||
return(masks_out,)
|
||||
class WeightScheduleExtend:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"input_values_1": ("FLOAT", {"default": 0.0, "forceInput": True}),
|
||||
"input_values_2": ("FLOAT", {"default": 0.0, "forceInput": True}),
|
||||
"output_type": (
|
||||
[
|
||||
'match_input',
|
||||
'list',
|
||||
'list of lists',
|
||||
'pandas series',
|
||||
'tensor',
|
||||
],
|
||||
{
|
||||
"default": 'match_input'
|
||||
}),
|
||||
},
|
||||
|
||||
}
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "KJNodes"
|
||||
DESCRIPTION = """
|
||||
Converts different value lists/series to another type.
|
||||
"""
|
||||
|
||||
def detect_input_type(self, input_values):
|
||||
import pandas as pd
|
||||
if isinstance(input_values, list):
|
||||
return 'list'
|
||||
elif isinstance(input_values, pd.Series):
|
||||
return 'pandas series'
|
||||
elif isinstance(input_values, torch.Tensor):
|
||||
return 'tensor'
|
||||
elif isinstance(input_values, list) and all(isinstance(sub, list) for sub in input_values):
|
||||
return 'list of lists'
|
||||
else:
|
||||
raise ValueError("Unsupported input type")
|
||||
|
||||
def execute(self, input_values_1, input_values_2, output_type):
|
||||
import pandas as pd
|
||||
input_type_1 = self.detect_input_type(input_values_1)
|
||||
input_type_2 = self.detect_input_type(input_values_2)
|
||||
# Convert input_values_2 to the same format as input_values_1 if they do not match
|
||||
if not input_type_1 == input_type_2:
|
||||
print("Converting input_values_2 to the same format as input_values_1")
|
||||
if input_type_1 == 'list of lists':
|
||||
# Assuming input_values_2 is a flat list, convert it to a list of lists
|
||||
float_values_2 = [[item] for item in input_values_2]
|
||||
elif input_type_1 == 'pandas series':
|
||||
# Convert input_values_2 to a pandas Series
|
||||
float_values_2 = pd.Series(input_values_2)
|
||||
elif input_type_1 == 'tensor':
|
||||
# Convert input_values_2 to a tensor
|
||||
float_values_2 = torch.tensor(input_values_2, dtype=torch.float32)
|
||||
else:
|
||||
print("Input types match, no conversion needed")
|
||||
# If the types match, no conversion is needed
|
||||
float_values_2 = input_values_2
|
||||
|
||||
float_values = input_values_1 + float_values_2
|
||||
|
||||
if output_type == 'list':
|
||||
return float_values,
|
||||
elif output_type == 'list of lists':
|
||||
return [[value] for value in float_values],
|
||||
elif output_type == 'pandas series':
|
||||
return pd.Series(float_values),
|
||||
elif output_type == 'tensor':
|
||||
if input_type_1 == 'pandas series':
|
||||
return torch.tensor(input_values_1.values, dtype=torch.float32),
|
||||
elif output_type == 'match_input':
|
||||
return float_values,
|
||||
else:
|
||||
raise ValueError(f"Unsupported output_type: {output_type}")
|
||||
Loading…
x
Reference in New Issue
Block a user