mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-21 10:07:22 +08:00
optimize draw mask on image
This commit is contained in:
parent
b9abf5df31
commit
d1f6cf77b0
@ -1529,8 +1529,11 @@ class DrawMaskOnImage:
|
|||||||
"image": ("IMAGE", ),
|
"image": ("IMAGE", ),
|
||||||
"mask": ("MASK", ),
|
"mask": ("MASK", ),
|
||||||
"color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, separated by commas."}),
|
"color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, separated by commas."}),
|
||||||
}
|
},
|
||||||
|
"optional": {
|
||||||
|
"device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", )
|
RETURN_TYPES = ("IMAGE", )
|
||||||
RETURN_NAMES = ("images",)
|
RETURN_NAMES = ("images",)
|
||||||
@ -1538,11 +1541,14 @@ class DrawMaskOnImage:
|
|||||||
CATEGORY = "KJNodes/masking"
|
CATEGORY = "KJNodes/masking"
|
||||||
DESCRIPTION = "Applies the provided masks to the input images."
|
DESCRIPTION = "Applies the provided masks to the input images."
|
||||||
|
|
||||||
def apply(self, image, mask, color):
|
def apply(self, image, mask, color, device="cpu"):
|
||||||
B, H, W, C = image.shape
|
B, H, W, C = image.shape
|
||||||
BM, HM, WM = mask.shape
|
BM, HM, WM = mask.shape
|
||||||
|
|
||||||
in_masks = mask.clone()
|
processing_device = main_device if device == "gpu" else torch.device("cpu")
|
||||||
|
|
||||||
|
in_masks = mask.clone().to(processing_device)
|
||||||
|
in_images = image.clone().to(processing_device)
|
||||||
|
|
||||||
if HM != H or WM != W:
|
if HM != H or WM != W:
|
||||||
in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
|
in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
|
||||||
@ -1562,12 +1568,12 @@ class DrawMaskOnImage:
|
|||||||
else:
|
else:
|
||||||
bg_values.append(int(val_str) / 255.0)
|
bg_values.append(int(val_str) / 255.0)
|
||||||
|
|
||||||
background_color = torch.tensor(bg_values, dtype=torch.float32, device=image.device)
|
background_color = torch.tensor(bg_values, dtype=torch.float32, device=in_images.device)
|
||||||
|
|
||||||
for i in range(B):
|
for i in tqdm(range(B), desc="DrawMaskOnImage batch"):
|
||||||
curr_mask = in_masks[i]
|
curr_mask = in_masks[i]
|
||||||
img_idx = min(i, B - 1)
|
img_idx = min(i, B - 1)
|
||||||
curr_image = image[img_idx]
|
curr_image = in_images[img_idx]
|
||||||
mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)
|
mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)
|
||||||
masked_image = curr_image * (1 - mask_expanded) + background_color * (mask_expanded)
|
masked_image = curr_image * (1 - mask_expanded) + background_color * (mask_expanded)
|
||||||
output_images.append(masked_image)
|
output_images.append(masked_image)
|
||||||
@ -1576,7 +1582,7 @@ class DrawMaskOnImage:
|
|||||||
if not output_images:
|
if not output_images:
|
||||||
return (torch.zeros((0, H, W, 3), dtype=image.dtype),)
|
return (torch.zeros((0, H, W, 3), dtype=image.dtype),)
|
||||||
|
|
||||||
out_rgb = torch.stack(output_images, dim=0)
|
out_rgb = torch.stack(output_images, dim=0).cpu()
|
||||||
|
|
||||||
return (out_rgb, )
|
return (out_rgb, )
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user