Fix GrowMaskWithBlur output shape

This commit is contained in:
kijai 2024-01-05 19:02:10 +02:00
parent ae9f3b1f0c
commit 6f9bbe3ccf

View File

@ -693,13 +693,15 @@ class GrowMaskWithBlur:
# Convert the tensor list to PIL images, apply blur, and convert back
for idx, tensor in enumerate(out):
# Convert tensor to PIL image
pil_image = TF.to_pil_image(tensor.cpu().detach())
#pil_image = TF.to_pil_image(tensor.cpu().detach())
pil_image = tensor2pil(tensor.cpu().detach())[0]
# Apply Gaussian blur
pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
# Convert back to tensor
out[idx] = TF.to_tensor(pil_image)
blurred = torch.stack(out, dim=0)
out[idx] = pil2tensor(pil_image)
blurred = torch.cat(out, dim=0)
print(blurred.shape)
return (blurred, 1.0 - blurred)