Add DrawMaskOnImage

This commit is contained in:
kijai 2025-09-18 18:05:34 +03:00
parent c2712d7781
commit 8d7a3a3bab
2 changed files with 59 additions and 0 deletions

View File

@ -22,6 +22,7 @@ NODE_CONFIG = {
"ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
"CondPassThrough": {"class": CondPassThrough},
#masking
"DrawMaskOnImage": {"class": DrawMaskOnImage, "name": "Draw Mask On Image"},
"DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},

View File

@ -1502,3 +1502,61 @@ class ConsolidateMasksKJ:
print(f"Consolidated {B} masks into {len(final_masks)}")
return (torch.stack(final_masks, dim=0),)
class DrawMaskOnImage:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"mask": ("MASK", ),
"color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, separated by commas."}),
}
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images",)
FUNCTION = "apply"
CATEGORY = "KJNodes/image"
DESCRIPTION = "Applies the provided masks to the input images."
def apply(self, image, mask, color):
B, H, W, C = image.shape
BM, HM, WM = mask.shape
in_masks = mask.clone()
if HM != H or WM != W:
in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
if B > BM:
in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B]
elif BM > B:
in_masks = in_masks[:B]
output_images = []
# Parse background color - detect if values are integers or floats
bg_values = []
for x in color.split(","):
val_str = x.strip()
if '.' in val_str:
bg_values.append(float(val_str))
else:
bg_values.append(int(val_str) / 255.0)
background_color = torch.tensor(bg_values, dtype=torch.float32, device=image.device)
for i in range(B):
curr_mask = in_masks[i]
img_idx = min(i, B - 1)
curr_image = image[img_idx]
mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)
masked_image = curr_image * (1 - mask_expanded) + background_color * (mask_expanded)
output_images.append(masked_image)
# If no masks were processed, return empty tensor
if not output_images:
return (torch.zeros((0, H, W, 3), dtype=image.dtype),)
out_rgb = torch.stack(output_images, dim=0)
return (out_rgb, )