mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-31 20:27:09 +08:00
Update mask_nodes.py
This commit is contained in:
parent
68471f65b3
commit
6ca2bb2708
@ -56,9 +56,8 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
if use_cuda and torch.cuda.is_available():
|
device = model_management.get_torch_device()
|
||||||
device = model_management.get_torch_device()
|
if not use_cuda:
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = model_management.unet_dtype()
|
dtype = model_management.unet_dtype()
|
||||||
|
|
||||||
@ -96,11 +95,9 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
outputs = self.model(**input_prc)
|
outputs = self.model(**input_prc)
|
||||||
|
|
||||||
tensor = torch.sigmoid(outputs.logits)
|
tensor = torch.sigmoid(outputs.logits)
|
||||||
print(tensor.min(), tensor.max())
|
|
||||||
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
||||||
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
|
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
|
||||||
|
|
||||||
|
|
||||||
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
||||||
tensor = tensor.squeeze(1)
|
tensor = tensor.squeeze(1)
|
||||||
|
|
||||||
@ -121,9 +118,12 @@ Segments an image or batch of images using CLIPSeg.
|
|||||||
model_management.soft_empty_cache()
|
model_management.soft_empty_cache()
|
||||||
|
|
||||||
if prev_mask is not None:
|
if prev_mask is not None:
|
||||||
tensor = tensor + prev_mask
|
if prev_mask.shape != tensor.shape:
|
||||||
|
prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
|
||||||
|
tensor = tensor + prev_mask.to(device)
|
||||||
torch.clamp(tensor, min=0.0, max=1.0)
|
torch.clamp(tensor, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
tensor = tensor.cpu().float()
|
||||||
return tensor,
|
return tensor,
|
||||||
|
|
||||||
class DownloadAndLoadCLIPSeg:
|
class DownloadAndLoadCLIPSeg:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user