From 6e1fa8d37a529f5083c12944f48926309f66aa68 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 19 Apr 2024 16:55:29 +0300 Subject: [PATCH] Add FloatToMask --- nodes.py | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 6e46ccd..863bcdd 100644 --- a/nodes.py +++ b/nodes.py @@ -4822,6 +4822,46 @@ and returns it as a float value. return pd.Series(mean_values), else: raise ValueError(f"Unsupported output_type: {output_type}") +class FloatToMask: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_values": ("FLOAT", {"forceInput": True, "default": 0}), + "width": ("INT", {"default": 100, "min": 1}), + "height": ("INT", {"default": 100, "min": 1}), + }, + } + RETURN_TYPES = ("MASK",) + FUNCTION = "execute" + CATEGORY = "KJNodes" + DESCRIPTION = """ +Generates a batch of masks based on the input float values. +The batch size is determined by the length of the input float values. +Each mask is generated with the specified width and height. +""" + + def execute(self, input_values, width, height): + import pandas as pd + # Ensure input_values is a list + if isinstance(input_values, (float, int)): + input_values = [input_values] + elif isinstance(input_values, pd.Series): + input_values = input_values.tolist() + elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values): + input_values = [item for sublist in input_values for item in sublist] + + # Generate a batch of masks based on the input_values + masks = [] + for value in input_values: + # Assuming value is a float between 0 and 1 representing the mask's intensity + mask = torch.ones((height, width), dtype=torch.float32) * value + masks.append(mask) + masks_out = torch.stack(masks, dim=0) + print(masks_out.shape) + return(masks_out,) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, @@ -4904,7 +4944,8 @@ NODE_CLASS_MAPPINGS = { "SplineEditor": SplineEditor, "ImageAndMaskPreview": ImageAndMaskPreview, "StabilityAPI_SD3": StabilityAPI_SD3, - "MaskOrImageToWeight": MaskOrImageToWeight + "MaskOrImageToWeight": MaskOrImageToWeight, + "FloatToMask": FloatToMask } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -4988,4 +5029,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageAndMaskPreview": "Image & Mask Preview", "StabilityAPI_SD3": "Stability API SD3", "MaskOrImageToWeight": "Mask Or Image To Weight", + "FloatToMask": "Float To Mask", } \ No newline at end of file