Blur with PIL instead

kijai 2024-01-01 20:30:58 +02:00
parent 6010472e5f
commit da613ecaa5


@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from torchvision.transforms import Resize, CenterCrop, InterpolationMode
 from torchvision.transforms import functional as TF
 import scipy.ndimage
-from scipy.spatial import Voronoi, voronoi_plot_2d
+from scipy.spatial import Voronoi
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
@@ -638,19 +638,12 @@ class GrowMaskWithBlur:
                 "incremental_expandrate": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1}),
                 "tapered_corners": ("BOOLEAN", {"default": True}),
                 "flip_input": ("BOOLEAN", {"default": False}),
-                "use_cuda": ("BOOLEAN", {"default": True}),
                 "blur_radius": ("INT", {
                     "default": 0,
                     "min": 0,
                     "max": 999,
                     "step": 1
                 }),
-                "sigma": ("FLOAT", {
-                    "default": 1.0,
-                    "min": 0.1,
-                    "max": 10.0,
-                    "step": 0.1
-                }),
                 "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
@@ -662,7 +655,7 @@ class GrowMaskWithBlur:
     RETURN_NAMES = ("mask", "mask_inverted",)
     FUNCTION = "expand_mask"
 
-    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda, lerp_alpha, decay_factor):
+    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor):
         alpha = lerp_alpha
         decay = decay_factor
         if( flip_input ):
@@ -696,22 +689,18 @@ class GrowMaskWithBlur:
             previous_output = output
             out.append(output)
-        blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
-        if use_cuda:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            blurred = blurred.to(device) # Move blurred tensor to the GPU
-        channels = blurred.shape[-1]
         if blur_radius != 0:
-            blurkernel_size = blur_radius * 2 + 1
-            blurkernel = gaussian_kernel(blurkernel_size, sigma, device=blurred.device).repeat(channels, 1, 1).unsqueeze(1)
-            blurred = blurred.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
-            padded_image = F.pad(blurred, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
-            blurred = F.conv2d(padded_image, blurkernel, padding=blurkernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
-            blurred = blurred.permute(0, 2, 3, 1)
-            blurred = blurred[:, :, :, 0]
-            return (blurred.cpu(), 1.0 - blurred.cpu(),)
-        return (torch.stack(out, dim=0), 1.0 -torch.stack(out, dim=0),)
+            # Convert the tensor list to PIL images, apply blur, and convert back
+            for idx, tensor in enumerate(out):
+                # Convert tensor to PIL image
+                pil_image = TF.to_pil_image(tensor.cpu().detach())
+                # Apply Gaussian blur
+                pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
+                # Convert back to tensor
+                out[idx] = TF.to_tensor(pil_image)
+        blurred = torch.stack(out, dim=0)
+        return (blurred, 1.0 - blurred)
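The new blur path above converts each mask tensor in `out` to a PIL image, applies `ImageFilter.GaussianBlur`, and converts back with torchvision's functional transforms, which removes the need for the `use_cuda`/`sigma` inputs and the custom `gaussian_kernel` convolution. A minimal standalone sketch of that round trip; the helper name, mask shapes, and radius below are illustrative, not taken from the commit:

```python
import torch
from torchvision.transforms import functional as TF
from PIL import ImageFilter

def blur_mask_with_pil(mask: torch.Tensor, blur_radius: int) -> torch.Tensor:
    """Blur a single-channel (H, W) mask tensor in [0, 1] via PIL's GaussianBlur."""
    if blur_radius == 0:
        return mask
    # to_pil_image accepts an (H, W) float tensor and yields an 8-bit "L" image
    pil_image = TF.to_pil_image(mask.cpu().detach())
    pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
    # to_tensor returns a (1, H, W) float tensor scaled back to [0, 1]
    return TF.to_tensor(pil_image)[0]

# Blur a small batch of masks the same way the node loops over `out`
masks = [torch.rand(64, 64) for _ in range(4)]
blurred = torch.stack([blur_mask_with_pil(m, blur_radius=5) for m in masks], dim=0)
print(blurred.shape)  # torch.Size([4, 64, 64])
```

One side effect worth noting: with default arguments, `to_pil_image` converts a float mask to an 8-bit image before blurring, so mask values are quantized to 1/255 steps on the way through PIL.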
@@ -2135,6 +2124,12 @@ class OffsetMask:
         return mask,
 
+class AnyType(str):
+    """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
+    def __ne__(self, __value: object) -> bool:
+        return False
+
+any = AnyType("*")
 
 class WidgetToString:
     @classmethod
     def IS_CHANGED(cls, **kwargs):
@@ -2148,6 +2143,9 @@ class WidgetToString:
                 "widget_name": ("STRING", {"multiline": False}),
                 "return_all": ("BOOLEAN", {"default": False}),
             },
+            "optional": {
+                "source": (any, {}),
+            },
             "hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
                        "prompt": "PROMPT"},
         }
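`AnyType` (credited to pythongosssss in its docstring) is the usual wildcard-input trick for ComfyUI: the connection check effectively compares type strings with `!=`, so a `str` subclass whose `__ne__` always returns `False` accepts a link of any type, which is what the new optional `source` input on `WidgetToString` relies on. A minimal sketch of the pattern; `PassthroughNode` and `any_type` are placeholders, not part of the commit:

```python
class AnyType(str):
    """str subclass that never compares as unequal, so checks of the form
    `declared_type != link_type` pass for every link type."""
    def __ne__(self, __value: object) -> bool:
        return False

any_type = AnyType("*")

# Illustrative node: the optional "source" socket accepts a link of any type,
# e.g. purely to force execution order or to read values from upstream.
class PassthroughNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"text": ("STRING", {"multiline": False})},
            "optional": {"source": (any_type, {})},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "examples"

    def run(self, text, source=None):
        return (text,)
```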
@@ -2939,7 +2937,20 @@ class ImageBatchRepeatInterleaving:
         repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0)
         return (repeated_images, )
 
+class MarigoldVAELoader:
+    #Paste stuffs from VAELoader nodes
+    def load_vae(self, vae_name):
+        if vae_name in ["taesd", "taesdxl"]:
+            sd = self.load_taesd(vae_name)
+        else:
+            vae_path = folder_paths.get_full_path("vae", vae_name)
+            sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
+        #Override the regularization
+        vae.first_stage_model.regularization = lambda x: torch.chunk(x, dim=2)
+        return (vae,)
+
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
     "FloatConstant": FloatConstant,
@@ -2994,7 +3005,8 @@ NODE_CLASS_MAPPINGS = {
     "GenerateNoise": GenerateNoise,
     "StableZero123_BatchSchedule": StableZero123_BatchSchedule,
     "GetImagesFromBatchIndexed": GetImagesFromBatchIndexed,
-    "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving
+    "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving,
+    "MarigoldVAELoader": MarigoldVAELoader
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -3049,5 +3061,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "GenerateNoise": "GenerateNoise",
     "StableZero123_BatchSchedule": "StableZero123_BatchSchedule",
     "GetImagesFromBatchIndexed": "GetImagesFromBatchIndexed",
-    "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving"
+    "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving",
+    "MarigoldVAELoader": "MarigoldVAELoader"
 }
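Registration itself is just these two module-level dictionaries: ComfyUI reads `NODE_CLASS_MAPPINGS` (internal name to node class) and `NODE_DISPLAY_NAME_MAPPINGS` (internal name to UI label) from a custom node package, so exposing `MarigoldVAELoader` only needs the two entries added above. A minimal sketch of the convention with a placeholder node class:

```python
class ExampleConstant:
    """Placeholder node that just returns its integer input."""
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"value": ("INT", {"default": 0, "min": 0, "max": 100})}}

    RETURN_TYPES = ("INT",)
    FUNCTION = "get_value"
    CATEGORY = "examples"

    def get_value(self, value):
        return (value,)

# ComfyUI imports these dictionaries from the custom node package and uses
# them to register the node under its internal and display names.
NODE_CLASS_MAPPINGS = {"ExampleConstant": ExampleConstant}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleConstant": "Example Constant"}
```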