diff --git a/nodes.py b/nodes.py index 5bb7995..a1098cb 100644 --- a/nodes.py +++ b/nodes.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torchvision.transforms import Resize, CenterCrop, InterpolationMode from torchvision.transforms import functional as TF import scipy.ndimage -from scipy.spatial import Voronoi, voronoi_plot_2d +from scipy.spatial import Voronoi import matplotlib.pyplot as plt import numpy as np from PIL import ImageFilter, Image, ImageDraw, ImageFont @@ -638,19 +638,12 @@ class GrowMaskWithBlur: "incremental_expandrate": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1}), "tapered_corners": ("BOOLEAN", {"default": True}), "flip_input": ("BOOLEAN", {"default": False}), - "use_cuda": ("BOOLEAN", {"default": True}), "blur_radius": ("INT", { "default": 0, "min": 0, "max": 999, "step": 1 }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, @@ -662,7 +655,7 @@ class GrowMaskWithBlur: RETURN_NAMES = ("mask", "mask_inverted",) FUNCTION = "expand_mask" - def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda, lerp_alpha, decay_factor): + def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor): alpha = lerp_alpha decay = decay_factor if( flip_input ): @@ -696,22 +689,18 @@ class GrowMaskWithBlur: previous_output = output out.append(output) - blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - if use_cuda: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - blurred = blurred.to(device) # Move blurred tensor to the GPU - - channels = blurred.shape[-1] if blur_radius != 0: - blurkernel_size = blur_radius * 2 + 1 - blurkernel = gaussian_kernel(blurkernel_size, sigma, device=blurred.device).repeat(channels, 1, 1).unsqueeze(1) - blurred = blurred.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - padded_image = F.pad(blurred, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') - blurred = F.conv2d(padded_image, blurkernel, padding=blurkernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] - blurred = blurred.permute(0, 2, 3, 1) - blurred = blurred[:, :, :, 0] - return (blurred.cpu(), 1.0 - blurred.cpu(),) - return (torch.stack(out, dim=0), 1.0 -torch.stack(out, dim=0),) + # Convert the tensor list to PIL images, apply blur, and convert back + for idx, tensor in enumerate(out): + # Convert tensor to PIL image + pil_image = TF.to_pil_image(tensor.cpu().detach()) + # Apply Gaussian blur + pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius)) + # Convert back to tensor + out[idx] = TF.to_tensor(pil_image) + blurred = torch.stack(out, dim=0) + + return (blurred, 1.0 - blurred) @@ -2135,6 +2124,12 @@ class OffsetMask: return mask, +class AnyType(str): + """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False +any = AnyType("*") class WidgetToString: @classmethod def IS_CHANGED(cls, **kwargs): @@ -2148,6 +2143,9 @@ class WidgetToString: "widget_name": ("STRING", {"multiline": False}), "return_all": ("BOOLEAN", {"default": False}), }, + "optional": { + "source": (any, {}), + }, "hidden": {"extra_pnginfo": "EXTRA_PNGINFO", "prompt": "PROMPT"}, } @@ -2939,7 +2937,20 @@ class ImageBatchRepeatInterleaving: repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0) return (repeated_images, ) - + +class MarigoldVAELoader: + #Paste stuffs from VAELoader nodes + def load_vae(self, vae_name): + if vae_name in ["taesd", "taesdxl"]: + sd = self.load_taesd(vae_name) + else: + vae_path = folder_paths.get_full_path("vae", vae_name) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + #Override the regularization + vae.first_stage_model.regularization = lambda x: torch.chunk(x, dim=2) + return (vae,) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -2994,7 +3005,8 @@ NODE_CLASS_MAPPINGS = { "GenerateNoise": GenerateNoise, "StableZero123_BatchSchedule": StableZero123_BatchSchedule, "GetImagesFromBatchIndexed": GetImagesFromBatchIndexed, - "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving + "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving, + "MarigoldVAELoader": MarigoldVAELoader } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -3049,5 +3061,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "GenerateNoise": "GenerateNoise", "StableZero123_BatchSchedule": "StableZero123_BatchSchedule", "GetImagesFromBatchIndexed": "GetImagesFromBatchIndexed", - "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving" + "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving", + "MarigoldVAELoader": "MarigoldVAELoader" } \ No newline at end of file