From b2d5ab3fcdc81c807f4049255abb65d38b608992 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 13 May 2024 15:06:38 +0300 Subject: [PATCH] Update mask_nodes.py --- nodes/mask_nodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 44ea4ff..0f76c70 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1130,10 +1130,9 @@ Resizes the mask or batch of masks to the specified width and height. ratio = min(width / ow, height / oh) width = round(ow*ratio) height = round(oh*ratio) - - outputs = mask.unsqueeze(0) # Add an extra dimension for batch size + outputs = mask.unsqueeze(1) outputs = F.interpolate(outputs, size=(height, width), mode="nearest") - outputs = outputs.squeeze(0) # Remove the extra dimension after interpolation + outputs = outputs.squeeze(1) return(outputs, outputs.shape[2], outputs.shape[1],)