diff --git a/nodes.py b/nodes.py index 6d1bc22..c7644bb 100644 --- a/nodes.py +++ b/nodes.py @@ -2250,11 +2250,8 @@ class BatchCLIPSeg: tensor = tensor_normalized - # Add extra dimensions to the mask for batch and channel - tensor = tensor[None, None, :, :] - # Resize the mask - resized_tensor = F.interpolate(tensor, size=(height, width), mode='bilinear', align_corners=False) + resized_tensor = F.interpolate(tensor.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False) # Remove the extra dimensions resized_tensor = resized_tensor[0, 0, :, :]