diff --git a/nodes.py b/nodes.py index f4f3704..c4b9ea1 100644 --- a/nodes.py +++ b/nodes.py @@ -2472,9 +2472,11 @@ Segments an image or batch of images using CLIPSeg. tensor = torch.sigmoid(outputs[0]) tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float)) tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min()) - tensor = tensor_normalized.unsqueeze(0).unsqueeze(0) + tensor = tensor_normalized # Resize the mask + if len(tensor.shape) == 3: + tensor = tensor.unsqueeze(0) resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest') # Remove the extra dimensions