Add FloatToMask

This commit is contained in:
kijai 2024-04-19 16:55:29 +03:00
parent 85724229f4
commit 6e1fa8d37a

View File

@ -4822,6 +4822,46 @@ and returns it as a float value.
return pd.Series(mean_values),
else:
raise ValueError(f"Unsupported output_type: {output_type}")
class FloatToMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values": ("FLOAT", {"forceInput": True, "default": 0}),
"width": ("INT", {"default": 100, "min": 1}),
"height": ("INT", {"default": 100, "min": 1}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "execute"
CATEGORY = "KJNodes"
DESCRIPTION = """
Generates a batch of masks based on the input float values.
The batch size is determined by the length of the input float values.
Each mask is generated with the specified width and height.
"""
def execute(self, input_values, width, height):
import pandas as pd
# Ensure input_values is a list
if isinstance(input_values, (float, int)):
input_values = [input_values]
elif isinstance(input_values, pd.Series):
input_values = input_values.tolist()
elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values):
input_values = [item for sublist in input_values for item in sublist]
# Generate a batch of masks based on the input_values
masks = []
for value in input_values:
# Assuming value is a float between 0 and 1 representing the mask's intensity
mask = torch.ones((height, width), dtype=torch.float32) * value
masks.append(mask)
masks_out = torch.stack(masks, dim=0)
print(masks_out.shape)
return(masks_out,)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
@ -4904,7 +4944,8 @@ NODE_CLASS_MAPPINGS = {
"SplineEditor": SplineEditor,
"ImageAndMaskPreview": ImageAndMaskPreview,
"StabilityAPI_SD3": StabilityAPI_SD3,
"MaskOrImageToWeight": MaskOrImageToWeight
"MaskOrImageToWeight": MaskOrImageToWeight,
"FloatToMask": FloatToMask
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -4988,4 +5029,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageAndMaskPreview": "Image & Mask Preview",
"StabilityAPI_SD3": "Stability API SD3",
"MaskOrImageToWeight": "Mask Or Image To Weight",
"FloatToMask": "Float To Mask",
}