Update mask_nodes.py

This commit is contained in:
kijai 2024-05-13 22:58:17 +03:00
parent 68471f65b3
commit 6ca2bb2708

View File

@ -56,9 +56,8 @@ Segments an image or batch of images using CLIPSeg.
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torchvision.transforms as transforms
offload_device = model_management.unet_offload_device()
if use_cuda and torch.cuda.is_available():
device = model_management.get_torch_device()
else:
device = model_management.get_torch_device()
if not use_cuda:
device = torch.device("cpu")
dtype = model_management.unet_dtype()
@ -96,11 +95,9 @@ Segments an image or batch of images using CLIPSeg.
outputs = self.model(**input_prc)
tensor = torch.sigmoid(outputs.logits)
print(tensor.min(), tensor.max())
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
tensor = tensor.squeeze(1)
@ -121,9 +118,12 @@ Segments an image or batch of images using CLIPSeg.
model_management.soft_empty_cache()
if prev_mask is not None:
tensor = tensor + prev_mask
if prev_mask.shape != tensor.shape:
prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
tensor = tensor + prev_mask.to(device)
torch.clamp(tensor, min=0.0, max=1.0)
tensor = tensor.cpu().float()
return tensor,
class DownloadAndLoadCLIPSeg: