diff --git a/__init__.py b/__init__.py
index 0f6dfa2..8f2cfcf 100644
--- a/__init__.py
+++ b/__init__.py
@@ -71,6 +71,7 @@ NODE_CONFIG = {
     "ImageNoiseAugmentation": {"class": ImageNoiseAugmentation, "name": "Image Noise Augmentation"},
     "ImageNormalize_Neg1_To_1": {"class": ImageNormalize_Neg1_To_1, "name": "Image Normalize -1 to 1"},
     "ImagePass": {"class": ImagePass},
+    "ImagePadKJ": {"class": ImagePadKJ, "name": "ImagePad KJ"},
     "ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"},
     "ImagePadForOutpaintTargetSize": {"class": ImagePadForOutpaintTargetSize, "name": "Image Pad For Outpaint Target Size"},
     "ImagePrepForICLora": {"class": ImagePrepForICLora, "name": "Image Prep For ICLora"},
@@ -176,6 +177,7 @@ NODE_CONFIG = {
     "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"},
     "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
     "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
+    "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py
index ae52eae..dee5158 100644
--- a/nodes/image_nodes.py
+++ b/nodes/image_nodes.py
@@ -1105,6 +1105,7 @@ class ImagePrepForICLora:
             },
             "optional": {
                 "latent_image": ("IMAGE",),
+                "latent_mask": ("MASK",),
                 "reference_mask": ("MASK",),
             }
         }
@@ -1114,7 +1115,7 @@ class ImagePrepForICLora:
 
     CATEGORY = "image"
 
-    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None):
+    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None, latent_mask=None):
 
         if reference_mask is not None:
             if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
@@ -1127,7 +1128,7 @@ class ImagePrepForICLora:
         if reference_mask is not None:
             resized_mask = torch.nn.functional.interpolate(
                 reference_mask.unsqueeze(1),
-                size=(image.shape[1], image.shape[2]),
+                size=(H, W),
                 mode='nearest'
             ).squeeze(1)
             print(resized_mask.shape)
@@ -1145,16 +1146,30 @@ class ImagePrepForICLora:
         else:
             resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
         pad_image = resized_latent_image
+        if latent_mask is not None:
+            resized_latent_mask = torch.nn.functional.interpolate(
+                latent_mask.unsqueeze(1),
+                size=(pad_image.shape[1], pad_image.shape[2]),
+                mode='nearest'
+            ).squeeze(1)
 
         if border_width > 0:
             border = torch.zeros((B, output_height, border_width, C), device=image.device)
             padded_image = torch.cat((resized_image, border, pad_image), dim=2)
-            padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
-            padded_mask[:, :, :new_width + border_width] = 0
+            if latent_mask is not None:
+                padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, (new_width + border_width):] = resized_latent_mask
+            else:
+                padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, :new_width + border_width] = 0
         else:
             padded_image = torch.cat((resized_image, pad_image), dim=2)
-            padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
-            padded_mask[:, :, :new_width] = 0
+            if latent_mask is not None:
+                padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, new_width:] = resized_latent_mask
+            else:
+                padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, :new_width] = 0
 
         return (padded_image, padded_mask)
 
@@ -3059,3 +3074,82 @@ class ImageCropByMaskBatch:
         out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)
 
         return (out_rgb, out_masks)
+
+class ImagePadKJ:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "image": ("IMAGE", ),
+            "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+            "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+            "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+            "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+            "extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
+            "pad_mode": (["edge", "color"],),
+            "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
+            }
+            , "optional": {
+                "masks": ("MASK", ),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE", "MASK", )
+    RETURN_NAMES = ("images", "masks",)
+    FUNCTION = "pad"
+    CATEGORY = "KJNodes/image"
+    DESCRIPTION = "Pads the input images by the given amounts plus extra_padding, filling with either the edge colors or a specified color, and pads the optional masks to match."
+
+    def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, masks=None):
+        B, H, W, C = image.shape
+
+        # Resize masks to image dimensions if necessary
+        if masks is not None:
+            BM, HM, WM = masks.shape
+            if HM != H or WM != W:
+                masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
+
+        # Parse background color
+        bg_color = [int(x.strip())/255.0 for x in color.split(",")]
+        if len(bg_color) == 1:
+            bg_color = bg_color * 3  # Grayscale to RGB
+        bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device)
+
+        # Calculate padding sizes with extra padding
+        pad_left = left + extra_padding
+        pad_right = right + extra_padding
+        pad_top = top + extra_padding
+        pad_bottom = bottom + extra_padding
+
+        padded_width = W + pad_left + pad_right
+        padded_height = H + pad_top + pad_bottom
+        out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)
+
+        # Fill padded areas
+        for b in range(B):
+            if pad_mode == "edge":
+                # Pad using the mean color of each image edge
+                # Define edge pixels
+                top_edge = image[b, 0, :, :]
+                bottom_edge = image[b, H-1, :, :]
+                left_edge = image[b, :, 0, :]
+                right_edge = image[b, :, W-1, :]
+
+                # Fill borders with edge colors
+                out_image[b, :pad_top, :, :] = top_edge.mean(dim=0)
+                out_image[b, pad_top+H:, :, :] = bottom_edge.mean(dim=0)
+                out_image[b, :, :pad_left, :] = left_edge.mean(dim=0)
+                out_image[b, :, pad_left+W:, :] = right_edge.mean(dim=0)
+                out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
+            else:
+                # Pad with specified background color
+                out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0)  # Expand for H and W dimensions
+                out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
+
+        if masks is not None:
+            out_masks = torch.zeros((BM, padded_height, padded_width), dtype=masks.dtype, device=masks.device)
+            for m in range(BM):
+                out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = masks[m]
+        else:
+            out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device)
+
+        return (out_image, out_masks)
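A minimal sketch (not part of the diff) of how the new ImagePadKJ.pad could be exercised outside a ComfyUI graph; the import path, tensor sizes, and pad settings below are illustrative assumptions and require a working ComfyUI-KJNodes environment:

    import torch
    from nodes.image_nodes import ImagePadKJ  # hypothetical import path, adjust to your setup

    node = ImagePadKJ()
    image = torch.rand(1, 512, 768, 3)   # ComfyUI IMAGE layout: B, H, W, C with values in 0..1
    masks = torch.ones(1, 512, 768)      # matching MASK batch; resized internally if shapes differ

    padded, padded_masks = node.pad(
        image, left=32, right=32, top=0, bottom=0,
        extra_padding=16, color="0, 0, 0", pad_mode="edge", masks=masks,
    )
    # Each side grows by its own amount plus extra_padding:
    # H: 512 + 16 + 16 = 544, W: 768 + (32 + 16) + (32 + 16) = 864
    print(padded.shape)        # torch.Size([1, 544, 864, 3])
    print(padded_masks.shape)  # torch.Size([1, 544, 864])
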
diff --git a/nodes/nodes.py b/nodes/nodes.py
index 6e1d4e9..b625697 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -9,7 +9,7 @@ import importlib
 import model_management
 import folder_paths
 from nodes import MAX_RESOLUTION
-from comfy.utils import common_upscale, ProgressBar
+from comfy.utils import common_upscale, ProgressBar, load_torch_file
 
 script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts"))
@@ -1964,7 +1964,7 @@ class FluxBlockLoraLoader:
     CATEGORY = "KJNodes/experimental"
 
     def load_lora(self, model, strength_model, lora_name=None, opt_lora_path=None, blocks=None):
-        from comfy.utils import load_torch_file
+
         import comfy.lora
 
         if opt_lora_path:
@@ -2299,4 +2299,104 @@ class ImageNoiseAugmentation:
         image_noise = torch.randn_like(image) * sigma[:, None, None, None]
         image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
         image_out = image + image_noise
-        return image_out,
\ No newline at end of file
+        return image_out,
+
+class VAELoaderKJ:
+    @staticmethod
+    def vae_list():
+        vaes = folder_paths.get_filename_list("vae")
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+        sdxl_taesd_enc = False
+        sdxl_taesd_dec = False
+        sd1_taesd_enc = False
+        sd1_taesd_dec = False
+        sd3_taesd_enc = False
+        sd3_taesd_dec = False
+        f1_taesd_enc = False
+        f1_taesd_dec = False
+
+        for v in approx_vaes:
+            if v.startswith("taesd_decoder."):
+                sd1_taesd_dec = True
+            elif v.startswith("taesd_encoder."):
+                sd1_taesd_enc = True
+            elif v.startswith("taesdxl_decoder."):
+                sdxl_taesd_dec = True
+            elif v.startswith("taesdxl_encoder."):
+                sdxl_taesd_enc = True
+            elif v.startswith("taesd3_decoder."):
+                sd3_taesd_dec = True
+            elif v.startswith("taesd3_encoder."):
+                sd3_taesd_enc = True
+            elif v.startswith("taef1_encoder."):
+                f1_taesd_enc = True
+            elif v.startswith("taef1_decoder."):
+                f1_taesd_dec = True
+        if sd1_taesd_dec and sd1_taesd_enc:
+            vaes.append("taesd")
+        if sdxl_taesd_dec and sdxl_taesd_enc:
+            vaes.append("taesdxl")
+        if sd3_taesd_dec and sd3_taesd_enc:
+            vaes.append("taesd3")
+        if f1_taesd_dec and f1_taesd_enc:
+            vaes.append("taef1")
+        return vaes
+
+    @staticmethod
+    def load_taesd(name):
+        sd = {}
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+
+        encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
+        decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
+
+        enc = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
+        for k in enc:
+            sd["taesd_encoder.{}".format(k)] = enc[k]
+
+        dec = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
+        for k in dec:
+            sd["taesd_decoder.{}".format(k)] = dec[k]
+
+        if name == "taesd":
+            sd["vae_scale"] = torch.tensor(0.18215)
+            sd["vae_shift"] = torch.tensor(0.0)
+        elif name == "taesdxl":
+            sd["vae_scale"] = torch.tensor(0.13025)
+            sd["vae_shift"] = torch.tensor(0.0)
+        elif name == "taesd3":
+            sd["vae_scale"] = torch.tensor(1.5305)
+            sd["vae_shift"] = torch.tensor(0.0609)
+        elif name == "taef1":
+            sd["vae_scale"] = torch.tensor(0.3611)
+            sd["vae_shift"] = torch.tensor(0.1159)
+        return sd
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": { "vae_name": (s.vae_list(), ),
+                          "device": (["main_device", "cpu"],),
+                          "weight_dtype": (["bf16", "fp16", "fp32"],),
+            }
+        }
+
+    RETURN_TYPES = ("VAE",)
+    FUNCTION = "load_vae"
+
+    CATEGORY = "KJNodes/vae"
+
+    def load_vae(self, vae_name, device, weight_dtype):
+        from comfy.sd import VAE
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[weight_dtype]
+        if device == "main_device":
+            device = model_management.get_torch_device()
+        elif device == "cpu":
+            device = torch.device("cpu")
+        if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
+            sd = self.load_taesd(vae_name)
+        else:
+            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
+            sd = load_torch_file(vae_path)
+        vae = VAE(sd=sd, device=device, dtype=dtype)
+        return (vae,)
\ No newline at end of file
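
A minimal sketch (not part of the diff) of how the new VAELoaderKJ could be used outside a graph, assuming a working ComfyUI install, the comfy.sd.VAE encode/decode interface (B, H, W, C images in 0..1), and a VAE checkpoint in the configured "vae" folder; the file name and import path are placeholders:

    import torch
    from nodes.nodes import VAELoaderKJ  # hypothetical import path, adjust to your setup

    loader = VAELoaderKJ()
    # Loads the state dict via comfy.utils.load_torch_file and wraps it in comfy.sd.VAE on the
    # requested device/dtype; the names "taesd"/"taesdxl"/"taesd3"/"taef1" instead go through load_taesd().
    (vae,) = loader.load_vae("ae.safetensors", device="cpu", weight_dtype="fp32")

    latent = vae.encode(torch.rand(1, 512, 512, 3))  # IMAGE-layout pixels in, latents out
    decoded = vae.decode(latent)                      # back to B, H, W, C pixels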