diff --git a/__init__.py b/__init__.py
index bd82a69..bba36b7 100644
--- a/__init__.py
+++ b/__init__.py
@@ -19,6 +19,7 @@ NODE_CONFIG = {
     "ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
     "CondPassThrough": {"class": CondPassThrough},
     #masking
+    "DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
     "BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
     "ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
     "CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py
index 44ea4ff..347212c 100644
--- a/nodes/mask_nodes.py
+++ b/nodes/mask_nodes.py
@@ -31,7 +31,7 @@ class BatchCLIPSeg:
             {
                 "images": ("IMAGE",),
                 "text": ("STRING", {"multiline": False}),
-                "threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.001}),
+                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.001}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
@@ -39,6 +39,8 @@ class BatchCLIPSeg:
             "optional":
             {
                 "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+                "opt_model": ("CLIPSEGMODEL",),
+                "prev_mask": ("MASK", {"default": None}),
             }
         }
@@ -50,7 +52,7 @@ Segments an image or batch of images using CLIPSeg.
 """

-    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None):
         from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
         import torchvision.transforms as transforms
         offload_device = model_management.unet_offload_device()
@@ -59,10 +61,23 @@ Segments an image or batch of images using CLIPSeg.
         else:
             device = torch.device("cpu")
         dtype = model_management.unet_dtype()
-        if not hasattr(self, "model"):
-            self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-
-        processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+        if opt_model is None:
+            checkpoint_path = os.path.join(folder_paths.models_dir, 'clip_seg', 'clipseg-rd64-refined-fp16')
+            if not hasattr(self, "model"):
+                try:
+                    if not os.path.exists(checkpoint_path):
+                        from huggingface_hub import snapshot_download
+                        snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
+                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+                except Exception:
+                    checkpoint_path = "CIDAS/clipseg-rd64-refined"
+                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+            processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+        else:
+            self.model = opt_model['model']
+            processor = opt_model['processor']

         self.model.to(dtype).to(device)
@@ -81,20 +96,20 @@ Segments an image or batch of images using CLIPSeg.
         outputs = self.model(**input_prc)

         tensor = torch.sigmoid(outputs.logits)
-        tensor = torch.where(tensor > (threshold / 10), tensor, torch.tensor(0, dtype=torch.float))
+        print(tensor.min(), tensor.max())
         tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
+        tensor = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+
         tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
         tensor = tensor.squeeze(1)

         self.model.to(offload_device)

-        results = tensor.cpu().float()
-        print(results.min(), results.max())
         if binary_mask:
             tensor = (tensor > 0).float()
         if blur_sigma > 0:
-            kernel_size = int(6 * blur_sigma + 1)
+            kernel_size = int(6 * int(blur_sigma) + 1)
             blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
             tensor = blur(tensor)
@@ -105,8 +120,58 @@ Segments an image or batch of images using CLIPSeg.
         del outputs
         model_management.soft_empty_cache()

+        if prev_mask is not None:
+            tensor = tensor + prev_mask
+            tensor = torch.clamp(tensor, min=0.0, max=1.0)
+
         return tensor,

+class DownloadAndLoadCLIPSeg:
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                {
+                    "model": (
+                        [
+                            'Kijai/clipseg-rd64-refined-fp16',
+                            'CIDAS/clipseg-rd64-refined',
+                        ],
+                        {
+                            "default": 'Kijai/clipseg-rd64-refined-fp16'
+                        }),
+                },
+        }
+
+    CATEGORY = "KJNodes/masking"
+    RETURN_TYPES = ("CLIPSEGMODEL",)
+    RETURN_NAMES = ("clipseg_model",)
+    FUNCTION = "load_model"
+    DESCRIPTION = """
+Downloads and loads a CLIPSeg model with huggingface_hub
+to ComfyUI/models/clip_seg.
+"""
+
+    def load_model(self, model):
+        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+        checkpoint_path = os.path.join(folder_paths.models_dir, 'clip_seg', model.split("/")[-1])
+        if not hasattr(self, "model"):
+            if not os.path.exists(checkpoint_path):
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
+            self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+
+        processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+        clipseg_model = {}
+        clipseg_model['model'] = self.model
+        clipseg_model['processor'] = processor
+
+        return clipseg_model,
+
 class CreateTextMask:
     RETURN_TYPES = ("IMAGE", "MASK",)
diff --git a/requirements.txt b/requirements.txt
index 120143b..bd6369e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ pillow>=10.3.0
 scipy
 color-matcher
 matplotlib
+huggingface_hub
\ No newline at end of file
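
Note that this diff changes the threshold semantics in BatchCLIPSeg: previously the cutoff was applied to the raw sigmoid output as `threshold / 10`; now the mask is min-max normalized to [0, 1] first and `threshold` (new default 0.5) is compared against the normalized values directly. A minimal sketch of the new post-processing order, with `postprocess_logits` being a hypothetical standalone helper that is not part of the diff:

```python
import torch
import torch.nn.functional as F

def postprocess_logits(logits, threshold=0.5, out_size=(512, 512)):
    """Sigmoid -> min-max normalize -> threshold -> resize, mirroring the
    new order in BatchCLIPSeg.segment_image."""
    tensor = torch.sigmoid(logits)
    # Stretch to the full [0, 1] range so `threshold` is an absolute cutoff
    # on the normalized mask, not a fraction of the raw sigmoid output.
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
    tensor = torch.where(tensor > threshold, tensor, torch.zeros_like(tensor))
    tensor = F.interpolate(tensor.unsqueeze(1), size=out_size, mode='nearest')
    return tensor.squeeze(1)

# Example: two 352x352 CLIPSeg logit maps upscaled to 512x512 masks.
masks = postprocess_logits(torch.randn(2, 352, 352), threshold=0.5)
print(masks.shape)  # torch.Size([2, 512, 512])
```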
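The new `prev_mask` input lets chained BatchCLIPSeg nodes accumulate masks: the incoming mask is added to the freshly computed one and the sum is clamped back into [0, 1], giving a union-like combine of two soft masks. A tiny sketch of that compositing step:

```python
import torch

# Additive compositing as in the new prev_mask branch: overlapping regions
# would exceed 1.0 after addition, so clamp restores the valid mask range.
a = torch.rand(1, 512, 512)
b = torch.rand(1, 512, 512)
combined = torch.clamp(a + b, min=0.0, max=1.0)
```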
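The CLIPSEGMODEL handoff between the two nodes is a plain dict holding the transformers model and processor. A rough usage sketch, assuming the node classes are invoked directly inside a ComfyUI runtime (normally the graph executor makes these calls, and `images` here is just random data standing in for an IMAGE batch of shape (B, H, W, C) in [0, 1]):

```python
import torch

loader = DownloadAndLoadCLIPSeg()
clipseg_model, = loader.load_model(model='Kijai/clipseg-rd64-refined-fp16')
# clipseg_model == {'model': CLIPSegForImageSegmentation,
#                   'processor': CLIPSegProcessor}

images = torch.rand(2, 512, 512, 3)
seg = BatchCLIPSeg()
masks, = seg.segment_image(images=images, text="a cat", threshold=0.5,
                           binary_mask=True, combine_mask=False, use_cuda=True,
                           opt_model=clipseg_model)  # skips the download/fallback path
```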