From 60cd6b555a5b7c0fc84f90b0c849d6f3100fbee6 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 23 Apr 2024 00:47:18 +0300 Subject: [PATCH] Update nodes.py --- nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index f4f3704..c4b9ea1 100644 --- a/nodes.py +++ b/nodes.py @@ -2472,9 +2472,11 @@ Segments an image or batch of images using CLIPSeg. tensor = torch.sigmoid(outputs[0]) tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float)) tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min()) - tensor = tensor_normalized.unsqueeze(0).unsqueeze(0) + tensor = tensor_normalized # Resize the mask + if len(tensor.shape) == 3: + tensor = tensor.unsqueeze(0) resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest') # Remove the extra dimensions