mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-27 17:28:12 +08:00
Update mask_nodes.py
This commit is contained in:
parent
68471f65b3
commit
6ca2bb2708
@ -56,9 +56,8 @@ Segments an image or batch of images using CLIPSeg.
|
||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||
import torchvision.transforms as transforms
|
||||
offload_device = model_management.unet_offload_device()
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
device = model_management.get_torch_device()
|
||||
else:
|
||||
device = model_management.get_torch_device()
|
||||
if not use_cuda:
|
||||
device = torch.device("cpu")
|
||||
dtype = model_management.unet_dtype()
|
||||
|
||||
@ -96,11 +95,9 @@ Segments an image or batch of images using CLIPSeg.
|
||||
outputs = self.model(**input_prc)
|
||||
|
||||
tensor = torch.sigmoid(outputs.logits)
|
||||
print(tensor.min(), tensor.max())
|
||||
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
||||
tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
|
||||
|
||||
|
||||
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
||||
tensor = tensor.squeeze(1)
|
||||
|
||||
@ -121,9 +118,12 @@ Segments an image or batch of images using CLIPSeg.
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if prev_mask is not None:
|
||||
tensor = tensor + prev_mask
|
||||
if prev_mask.shape != tensor.shape:
|
||||
prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest')
|
||||
tensor = tensor + prev_mask.to(device)
|
||||
torch.clamp(tensor, min=0.0, max=1.0)
|
||||
|
||||
tensor = tensor.cpu().float()
|
||||
return tensor,
|
||||
|
||||
class DownloadAndLoadCLIPSeg:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user