From 6f9bbe3ccfe2e58f918302f4fb14ed6d06828916 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:02:10 +0200 Subject: [PATCH] Fix GrowMaskWithBlur output shape --- nodes.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 780f1b9..45f3134 100644 --- a/nodes.py +++ b/nodes.py @@ -693,13 +693,15 @@ class GrowMaskWithBlur: # Convert the tensor list to PIL images, apply blur, and convert back for idx, tensor in enumerate(out): # Convert tensor to PIL image - pil_image = TF.to_pil_image(tensor.cpu().detach()) + #pil_image = TF.to_pil_image(tensor.cpu().detach()) + pil_image = tensor2pil(tensor.cpu().detach())[0] # Apply Gaussian blur pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius)) # Convert back to tensor - out[idx] = TF.to_tensor(pil_image) - blurred = torch.stack(out, dim=0) - + out[idx] = pil2tensor(pil_image) + + blurred = torch.cat(out, dim=0) + print(blurred.shape) return (blurred, 1.0 - blurred)