diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 347212c..e39c09e 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1195,10 +1195,9 @@ Resizes the mask or batch of masks to the specified width and height. ratio = min(width / ow, height / oh) width = round(ow*ratio) height = round(oh*ratio) - - outputs = mask.unsqueeze(0) # Add an extra dimension for batch size + outputs = mask.unsqueeze(1) outputs = F.interpolate(outputs, size=(height, width), mode="nearest") - outputs = outputs.squeeze(0) # Remove the extra dimension after interpolation + outputs = outputs.squeeze(1) return(outputs, outputs.shape[2], outputs.shape[1],)