From 33ef9743709efcd4ded5e921c19d09fdcb80e390 Mon Sep 17 00:00:00 2001
From: Kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 14 May 2024 12:12:09 +0300
Subject: [PATCH] Update mask_nodes.py

---
 nodes/mask_nodes.py | 51 +++++++++++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py
index 4ed74cb..bbadd13 100644
--- a/nodes/mask_nodes.py
+++ b/nodes/mask_nodes.py
@@ -41,18 +41,20 @@ class BatchCLIPSeg:
                 "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                 "opt_model": ("CLIPSEGMODEL", ),
                 "prev_mask": ("MASK", {"default": None}),
+                "image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "invert": ("BOOLEAN", {"default": False}),
             }
         }
 
     CATEGORY = "KJNodes/masking"
-    RETURN_TYPES = ("MASK",)
-    RETURN_NAMES = ("Mask",)
+    RETURN_TYPES = ("MASK", "IMAGE", )
+    RETURN_NAMES = ("Mask", "Image", )
     FUNCTION = "segment_image"
     DESCRIPTION = """
 Segments an image or batch of images using CLIPSeg.
 """
 
-    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert=False, image_bg_level=0.5):
         from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
         import torchvision.transforms as transforms
         offload_device = model_management.unet_offload_device()
@@ -86,45 +88,54 @@ Segments an image or batch of images using CLIPSeg.
         autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
         with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
 
-            images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
+            PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
             prompt = [text] * len(images)
-            input_prc = processor(text=prompt, images=images, return_tensors="pt")
+            input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")
 
             for key in input_prc:
                 input_prc[key] = input_prc[key].to(device)
             outputs = self.model(**input_prc)
 
-        tensor = torch.sigmoid(outputs.logits)
-        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
-        tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
-
-        tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
-        tensor = tensor.squeeze(1)
+        mask_tensor = torch.sigmoid(outputs.logits)
+        mask_tensor = (mask_tensor - mask_tensor.min()) / (mask_tensor.max() - mask_tensor.min())
+        mask_tensor = torch.where(mask_tensor > (threshold), mask_tensor, torch.tensor(0, dtype=torch.float))
+        print(mask_tensor.shape)
+        if len(mask_tensor.shape) == 2:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest')
+        mask_tensor = mask_tensor.squeeze(1)
 
         self.model.to(offload_device)
 
         if binary_mask:
-            tensor = (tensor > 0).float()
+            mask_tensor = (mask_tensor > 0).float()
         if blur_sigma > 0:
             kernel_size = int(6 * int(blur_sigma) + 1)
             blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
-            tensor = blur(tensor)
+            mask_tensor = blur(mask_tensor)
 
         if combine_mask:
-            tensor = torch.max(tensor, dim=0)[0]
-            tensor = tensor.unsqueeze(0).repeat(len(images),1,1)
+            mask_tensor = torch.max(mask_tensor, dim=0)[0]
+            mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images),1,1)
 
         del outputs
         model_management.soft_empty_cache()
 
         if prev_mask is not None:
-            if prev_mask.shape != tensor.shape:
+            if prev_mask.shape != mask_tensor.shape:
                 prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
-            tensor = tensor + prev_mask.to(device)
-            torch.clamp(tensor, min=0.0, max=1.0)
+            mask_tensor = mask_tensor + prev_mask.to(device)
+            mask_tensor = torch.clamp(mask_tensor, min=0.0, max=1.0)
 
-        tensor = tensor.cpu().float()
-        return tensor,
+        if invert:
+            mask_tensor = 1 - mask_tensor
+
+        image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
+        image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()
+
+        mask_tensor = mask_tensor.cpu().float()
+
+        return mask_tensor, image_tensor,
 
 class DownloadAndLoadCLIPSeg:
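
Note on the new IMAGE output: it is composited as images * mask + (1 - mask) * image_bg_level, so pixels inside the mask keep the source image and everything outside is replaced by a flat grey level. A minimal, self-contained sketch of that step, assuming ComfyUI's usual tensor layouts of (B, H, W, C) float images in [0, 1] and (B, H, W) masks; the variable names mirror the patch, while the dummy tensors and shapes are illustrative only:

    import torch

    # illustrative stand-ins for the node's inputs
    B, H, W, C = 2, 4, 4, 3
    images = torch.rand(B, H, W, C)                    # batch of RGB images in [0, 1]
    mask_tensor = (torch.rand(B, H, W) > 0.5).float()  # one mask per image
    invert = False
    image_bg_level = 0.5                               # mid-grey background

    if invert:
        mask_tensor = 1 - mask_tensor

    # unsqueeze(-1) reshapes (B, H, W) to (B, H, W, 1) so the mask
    # broadcasts across the channel axis of the (B, H, W, C) images
    m = mask_tensor.unsqueeze(-1)
    image_tensor = images * m + (1 - m) * image_bg_level
    image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0)

    print(image_tensor.shape)                          # torch.Size([2, 4, 4, 3])

Because the expression is a linear blend, a soft mask (e.g. after blur_sigma > 0) fades the subject smoothly into the background level rather than cutting a hard edge.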