diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 44ea4ff..0f76c70 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1130,10 +1130,9 @@ Resizes the mask or batch of masks to the specified width and height. ratio = min(width / ow, height / oh) width = round(ow*ratio) height = round(oh*ratio) - - outputs = mask.unsqueeze(0) # Add an extra dimension for batch size + outputs = mask.unsqueeze(1) outputs = F.interpolate(outputs, size=(height, width), mode="nearest") - outputs = outputs.squeeze(0) # Remove the extra dimension after interpolation + outputs = outputs.squeeze(1) return(outputs, outputs.shape[2], outputs.shape[1],)