diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py
index ab89998..44ea4ff 100644
--- a/nodes/mask_nodes.py
+++ b/nodes/mask_nodes.py
@@ -36,6 +36,10 @@ class BatchCLIPSeg:
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
             },
+            "optional":
+            {
+                "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+            }
         }
 
     CATEGORY = "KJNodes/masking"
@@ -46,55 +50,62 @@ class BatchCLIPSeg:
     Segments an image or batch of images using CLIPSeg.
     """
 
-    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
         from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
-        height, width, _ = images[0].shape
+        import torchvision.transforms as transforms
+        offload_device = model_management.unet_offload_device()
         if use_cuda and torch.cuda.is_available():
-            device = torch.device("cuda")
+            device = model_management.get_torch_device()
         else:
             device = torch.device("cpu")
         dtype = model_management.unet_dtype()
 
-        model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-        model.to(dtype)
-        model.to(device)
-        images = images.to(device)
+        if not hasattr(self, "model"):
+            self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+        self.model.to(dtype).to(device)
+
+        B, H, W, C = images.shape
+        images = images.to(device)
+
         autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
         with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
             images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images]
             prompt = [text] * len(images)
             input_prc = processor(text=prompt, images=images, return_tensors="pt")
-            # Move the processed input to the device
+
             for key in input_prc:
                 input_prc[key] = input_prc[key].to(device)
+            outputs = self.model(**input_prc)
 
-            outputs = model(**input_prc)
+        tensor = torch.sigmoid(outputs.logits)
+        tensor = torch.where(tensor > (threshold / 10), tensor, torch.tensor(0, dtype=torch.float))
+        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
 
-        tensor = torch.sigmoid(outputs.logits)
-        tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
-        tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
-        tensor = tensor_normalized
+        tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
+        tensor = tensor.squeeze(1)
 
-        # Resize the mask
-        resized_tensor = F.interpolate(tensor.unsqueeze(1), size=(height, width), mode='nearest')
-
-        # Remove the extra dimensions
-        resized_tensor = resized_tensor.squeeze(1)
-
-        results = resized_tensor.cpu().float()
+        self.model.to(offload_device)
+        results = tensor.cpu().float()
+        print(results.min(), results.max())
+
+        if binary_mask:
+            tensor = (tensor > 0).float()
+        if blur_sigma > 0:
+            kernel_size = int(6 * blur_sigma + 1)
+            blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
+            tensor = blur(tensor)
 
         if combine_mask:
-            combined_results = torch.max(results, dim=0)[0]
-            results = combined_results.unsqueeze(0).repeat(len(images),1,1)
+            tensor = torch.max(tensor, dim=0)[0]
+            tensor = tensor.unsqueeze(0).repeat(len(images),1,1)
 
-        if binary_mask:
-            results = results.round()
+        del outputs
+        model_management.soft_empty_cache()
 
-        del outputs, tensor, tensor_thresholded, tensor_normalized, resized_tensor, images
-        torch.cuda.empty_cache()
-
-        return results,
+        return tensor,
 
 
 class CreateTextMask:
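
Review note: torchvision's GaussianBlur requires an odd, positive kernel_size, and the heuristic int(6 * blur_sigma + 1) lands on an even number for some fractional sigmas (e.g. blur_sigma=0.5 gives 4), which raises a ValueError. A minimal sketch of a guard, using the same torchvision import as the patch (the helper name make_blur is illustrative, not part of the patch):

    import torchvision.transforms as transforms

    def make_blur(blur_sigma: float) -> transforms.GaussianBlur:
        # 6*sigma + 1 spans roughly +/-3 standard deviations of the kernel,
        # but GaussianBlur only accepts odd sizes, so bump even results by one.
        kernel_size = int(6 * blur_sigma + 1)
        if kernel_size % 2 == 0:
            kernel_size += 1
        return transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size),
                                       sigma=(blur_sigma, blur_sigma))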
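
For reference, a hypothetical smoke test (not part of the patch) showing the expected shapes, assuming a ComfyUI runtime where model_management is importable and the CLIPSeg weights can be downloaded: IMAGE inputs are (B, H, W, C) floats in [0, 1], and the returned MASK is (B, H, W).

    import torch

    node = BatchCLIPSeg()
    images = torch.rand(2, 512, 512, 3)  # two dummy 512x512 RGB frames
    masks, = node.segment_image(images, "a cat", threshold=1.0,
                                binary_mask=False, combine_mask=False,
                                use_cuda=False, blur_sigma=2.0)  # kernel_size 13, odd
    assert masks.shape == (2, 512, 512)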