mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-03-16 14:27:05 +08:00
clipseg improvements
This commit is contained in:
parent
4812eff6e5
commit
17a6b358af
@ -19,6 +19,7 @@ NODE_CONFIG = {
|
|||||||
"ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
|
"ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
|
||||||
"CondPassThrough": {"class": CondPassThrough},
|
"CondPassThrough": {"class": CondPassThrough},
|
||||||
#masking
|
#masking
|
||||||
|
"DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
|
||||||
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
|
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
|
||||||
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
|
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
|
||||||
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
|
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class BatchCLIPSeg:
|
|||||||
{
|
{
|
||||||
"images": ("IMAGE",),
|
"images": ("IMAGE",),
|
||||||
"text": ("STRING", {"multiline": False}),
|
"text": ("STRING", {"multiline": False}),
|
||||||
"threshold": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.001}),
|
"threshold": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 10.0, "step": 0.001}),
|
||||||
"binary_mask": ("BOOLEAN", {"default": True}),
|
"binary_mask": ("BOOLEAN", {"default": True}),
|
||||||
"combine_mask": ("BOOLEAN", {"default": False}),
|
"combine_mask": ("BOOLEAN", {"default": False}),
|
||||||
"use_cuda": ("BOOLEAN", {"default": True}),
|
"use_cuda": ("BOOLEAN", {"default": True}),
|
||||||
@ -39,6 +39,8 @@ class BatchCLIPSeg:
|
|||||||
"optional":
|
"optional":
|
||||||
{
|
{
|
||||||
"blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
"blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
||||||
|
"opt_model": ("CLIPSEGMODEL", ),
|
||||||
|
"prev_mask": ("MASK", {"default": None}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +52,7 @@ class BatchCLIPSeg:
|
|||||||
Segments an image or batch of images using CLIPSeg.
|
Segments an image or batch of images using CLIPSeg.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
|
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None):
|
||||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
@ -59,10 +61,23 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = model_management.unet_dtype()
|
dtype = model_management.unet_dtype()
|
||||||
if not hasattr(self, "model"):
|
|
||||||
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
if opt_model is None:
|
||||||
|
checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', 'clipseg-rd64-refined-fp16')
|
||||||
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
if not hasattr(self, "model"):
|
||||||
|
try:
|
||||||
|
if not os.path.exists(checkpoint_path):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
|
||||||
|
self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
|
||||||
|
except:
|
||||||
|
checkpoint_path = "CIDAS/clipseg-rd64-refined"
|
||||||
|
self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
|
||||||
|
processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.model = opt_model['model']
|
||||||
|
processor = opt_model['processor']
|
||||||
|
|
||||||
self.model.to(dtype).to(device)
|
self.model.to(dtype).to(device)
|
||||||
|
|
||||||
@ -81,20 +96,20 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
outputs = self.model(**input_prc)
|
outputs = self.model(**input_prc)
|
||||||
|
|
||||||
tensor = torch.sigmoid(outputs.logits)
|
tensor = torch.sigmoid(outputs.logits)
|
||||||
tensor = torch.where(tensor > (threshold / 10), tensor, torch.tensor(0, dtype=torch.float))
|
print(tensor.min(), tensor.max())
|
||||||
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
||||||
|
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
|
||||||
|
|
||||||
|
|
||||||
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
||||||
tensor = tensor.squeeze(1)
|
tensor = tensor.squeeze(1)
|
||||||
|
|
||||||
self.model.to(offload_device)
|
self.model.to(offload_device)
|
||||||
results = tensor.cpu().float()
|
|
||||||
print(results.min(), results.max())
|
|
||||||
|
|
||||||
if binary_mask:
|
if binary_mask:
|
||||||
tensor = (tensor > 0).float()
|
tensor = (tensor > 0).float()
|
||||||
if blur_sigma > 0:
|
if blur_sigma > 0:
|
||||||
kernel_size = int(6 * blur_sigma + 1)
|
kernel_size = int(6 * int(blur_sigma) + 1)
|
||||||
blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
|
blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
|
||||||
tensor = blur(tensor)
|
tensor = blur(tensor)
|
||||||
|
|
||||||
@ -105,8 +120,58 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
del outputs
|
del outputs
|
||||||
model_management.soft_empty_cache()
|
model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
if prev_mask is not None:
|
||||||
|
tensor = tensor + prev_mask
|
||||||
|
torch.clamp(tensor, min=0.0, max=1.0)
|
||||||
|
|
||||||
return tensor,
|
return tensor,
|
||||||
|
|
||||||
|
class DownloadAndLoadCLIPSeg:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
|
||||||
|
return {"required":
|
||||||
|
{
|
||||||
|
"model": (
|
||||||
|
[ 'Kijai/clipseg-rd64-refined-fp16',
|
||||||
|
'CIDAS/clipseg-rd64-refined',
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"default": 'clipseg-rd64-refined-fp16'
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes/masking"
|
||||||
|
RETURN_TYPES = ("CLIPSEGMODEL",)
|
||||||
|
RETURN_NAMES = ("clipseg_model",)
|
||||||
|
FUNCTION = "segment_image"
|
||||||
|
DESCRIPTION = """
|
||||||
|
Downloads and loads CLIPSeg model with huggingface_hub,
|
||||||
|
to ComfyUI/models/clip_seg
|
||||||
|
"""
|
||||||
|
|
||||||
|
def segment_image(self, model):
|
||||||
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||||
|
checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', model)
|
||||||
|
if not hasattr(self, "model"):
|
||||||
|
if not os.path.exists(checkpoint_path):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
snapshot_download(repo_id=model, local_dir=checkpoint_path.split("/")[-1], local_dir_use_symlinks=False)
|
||||||
|
self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
|
||||||
|
|
||||||
|
processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
|
||||||
|
|
||||||
|
clipseg_model = {}
|
||||||
|
clipseg_model['model'] = self.model
|
||||||
|
clipseg_model['processor'] = processor
|
||||||
|
|
||||||
|
return clipseg_model,
|
||||||
|
|
||||||
class CreateTextMask:
|
class CreateTextMask:
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK",)
|
RETURN_TYPES = ("IMAGE", "MASK",)
|
||||||
|
|||||||
@ -4,3 +4,4 @@ pillow>=10.3.0
|
|||||||
scipy
|
scipy
|
||||||
color-matcher
|
color-matcher
|
||||||
matplotlib
|
matplotlib
|
||||||
|
huggingface_hub
|
||||||
Loading…
x
Reference in New Issue
Block a user