mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-01 08:45:42 +08:00
Update mask_nodes.py
This commit is contained in:
parent
6ca2bb2708
commit
33ef974370
@ -41,18 +41,20 @@ class BatchCLIPSeg:
|
|||||||
"blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
"blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
||||||
"opt_model": ("CLIPSEGMODEL", ),
|
"opt_model": ("CLIPSEGMODEL", ),
|
||||||
"prev_mask": ("MASK", {"default": None}),
|
"prev_mask": ("MASK", {"default": None}),
|
||||||
|
"image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"invert": ("BOOLEAN", {"default": False}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "KJNodes/masking"
|
CATEGORY = "KJNodes/masking"
|
||||||
RETURN_TYPES = ("MASK",)
|
RETURN_TYPES = ("MASK", "IMAGE", )
|
||||||
RETURN_NAMES = ("Mask",)
|
RETURN_NAMES = ("Mask", "Image", )
|
||||||
FUNCTION = "segment_image"
|
FUNCTION = "segment_image"
|
||||||
DESCRIPTION = """
|
DESCRIPTION = """
|
||||||
Segments an image or batch of images using CLIPSeg.
|
Segments an image or batch of images using CLIPSeg.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None):
|
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert= False, image_bg_level=0.5):
|
||||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
@ -86,45 +88,54 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
|
autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
|
||||||
with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
|
with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
|
||||||
|
|
||||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
|
PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
|
||||||
prompt = [text] * len(images)
|
prompt = [text] * len(images)
|
||||||
input_prc = processor(text=prompt, images=images, return_tensors="pt")
|
input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")
|
||||||
|
|
||||||
for key in input_prc:
|
for key in input_prc:
|
||||||
input_prc[key] = input_prc[key].to(device)
|
input_prc[key] = input_prc[key].to(device)
|
||||||
outputs = self.model(**input_prc)
|
outputs = self.model(**input_prc)
|
||||||
|
|
||||||
tensor = torch.sigmoid(outputs.logits)
|
mask_tensor = torch.sigmoid(outputs.logits)
|
||||||
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
mask_tensor = (mask_tensor - mask_tensor.min()) / (mask_tensor.max() - mask_tensor.min())
|
||||||
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
|
mask_tensor = torch.where(mask_tensor > (threshold), mask_tensor, torch.tensor(0, dtype=torch.float))
|
||||||
|
print(mask_tensor.shape)
|
||||||
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
if len(mask_tensor.shape) == 2:
|
||||||
tensor = tensor.squeeze(1)
|
mask_tensor = mask_tensor.unsqueeze(0)
|
||||||
|
mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
||||||
|
mask_tensor = mask_tensor.squeeze(1)
|
||||||
|
|
||||||
self.model.to(offload_device)
|
self.model.to(offload_device)
|
||||||
|
|
||||||
if binary_mask:
|
if binary_mask:
|
||||||
tensor = (tensor > 0).float()
|
mask_tensor = (mask_tensor > 0).float()
|
||||||
if blur_sigma > 0:
|
if blur_sigma > 0:
|
||||||
kernel_size = int(6 * int(blur_sigma) + 1)
|
kernel_size = int(6 * int(blur_sigma) + 1)
|
||||||
blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
|
blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
|
||||||
tensor = blur(tensor)
|
mask_tensor = blur(mask_tensor)
|
||||||
|
|
||||||
if combine_mask:
|
if combine_mask:
|
||||||
tensor = torch.max(tensor, dim=0)[0]
|
mask_tensor = torch.max(mask_tensor, dim=0)[0]
|
||||||
tensor = tensor.unsqueeze(0).repeat(len(images),1,1)
|
mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images),1,1)
|
||||||
|
|
||||||
del outputs
|
del outputs
|
||||||
model_management.soft_empty_cache()
|
model_management.soft_empty_cache()
|
||||||
|
|
||||||
if prev_mask is not None:
|
if prev_mask is not None:
|
||||||
if prev_mask.shape != tensor.shape:
|
if prev_mask.shape != mask_tensor.shape:
|
||||||
prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
|
prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
|
||||||
tensor = tensor + prev_mask.to(device)
|
mask_tensor = mask_tensor + prev_mask.to(device)
|
||||||
torch.clamp(tensor, min=0.0, max=1.0)
|
torch.clamp(mask_tensor, min=0.0, max=1.0)
|
||||||
|
|
||||||
tensor = tensor.cpu().float()
|
if invert:
|
||||||
return tensor,
|
mask_tensor = 1 - mask_tensor
|
||||||
|
|
||||||
|
image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
|
||||||
|
image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()
|
||||||
|
|
||||||
|
mask_tensor = mask_tensor.cpu().float()
|
||||||
|
|
||||||
|
return mask_tensor, image_tensor,
|
||||||
|
|
||||||
class DownloadAndLoadCLIPSeg:
|
class DownloadAndLoadCLIPSeg:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user