From 6ca2bb27086610b54df8ea882095313601953bf1 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 13 May 2024 22:58:17 +0300 Subject: [PATCH] Update mask_nodes.py --- nodes/mask_nodes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index e39c09e..4ed74cb 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -56,9 +56,8 @@ Segments an image or batch of images using CLIPSeg. from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation import torchvision.transforms as transforms offload_device = model_management.unet_offload_device() - if use_cuda and torch.cuda.is_available(): - device = model_management.get_torch_device() - else: + device = model_management.get_torch_device() + if not use_cuda: device = torch.device("cpu") dtype = model_management.unet_dtype() @@ -96,11 +95,9 @@ Segments an image or batch of images using CLIPSeg. outputs = self.model(**input_prc) tensor = torch.sigmoid(outputs.logits) - print(tensor.min(), tensor.max()) tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float)) - tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest') tensor = tensor.squeeze(1) @@ -121,9 +118,12 @@ Segments an image or batch of images using CLIPSeg. model_management.soft_empty_cache() if prev_mask is not None: - tensor = tensor + prev_mask + if prev_mask.shape != tensor.shape: + prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest') + tensor = tensor + prev_mask.to(device) torch.clamp(tensor, min=0.0, max=1.0) + tensor = tensor.cpu().float() return tensor, class DownloadAndLoadCLIPSeg: