Add ImagePadKJ, VAELoaderKJ

Simple pad node, and a VAE loader that lets you choose the device and dtype
This commit is contained in:
kijai 2025-02-21 20:35:23 +02:00
parent f653a8e45e
commit 8950c5fe67
3 changed files with 205 additions and 9 deletions

View File

@@ -71,6 +71,7 @@ NODE_CONFIG = {
"ImageNoiseAugmentation": {"class": ImageNoiseAugmentation, "name": "Image Noise Augmentation"},
"ImageNormalize_Neg1_To_1": {"class": ImageNormalize_Neg1_To_1, "name": "Image Normalize -1 to 1"},
"ImagePass": {"class": ImagePass},
"ImagePadKJ": {"class": ImagePadKJ, "name": "ImagePad KJ"},
"ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"},
"ImagePadForOutpaintTargetSize": {"class": ImagePadForOutpaintTargetSize, "name": "Image Pad For Outpaint Target Size"},
"ImagePrepForICLora": {"class": ImagePrepForICLora, "name": "Image Prep For ICLora"},
@@ -176,6 +177,7 @@ NODE_CONFIG = {
"TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"},
"PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@@ -1105,6 +1105,7 @@ class ImagePrepForICLora:
},
"optional": {
"latent_image": ("IMAGE",),
"latent_mask": ("MASK",),
"reference_mask": ("MASK",),
}
}
@@ -1114,7 +1115,7 @@ class ImagePrepForICLora:
CATEGORY = "image"
def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None):
def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None, latent_mask=None):
if reference_mask is not None:
if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
@@ -1127,7 +1128,7 @@ class ImagePrepForICLora:
if reference_mask is not None:
resized_mask = torch.nn.functional.interpolate(
reference_mask.unsqueeze(1),
size=(image.shape[1], image.shape[2]),
size=(H, W),
mode='nearest'
).squeeze(1)
print(resized_mask.shape)
@@ -1145,16 +1146,30 @@ class ImagePrepForICLora:
else:
resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
pad_image = resized_latent_image
if latent_mask is not None:
resized_latent_mask = torch.nn.functional.interpolate(
latent_mask.unsqueeze(1),
size=(pad_image.shape[1], pad_image.shape[2]),
mode='nearest'
).squeeze(1)
if border_width > 0:
border = torch.zeros((B, output_height, border_width, C), device=image.device)
padded_image = torch.cat((resized_image, border, pad_image), dim=2)
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width + border_width] = 0
if latent_mask is not None:
padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, (new_width + border_width):] = resized_latent_mask
else:
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width + border_width] = 0
else:
padded_image = torch.cat((resized_image, pad_image), dim=2)
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width] = 0
if latent_mask is not None:
padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, new_width:] = resized_latent_mask
else:
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width] = 0
return (padded_image, padded_mask)
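To illustrate what the new latent_mask input changes: without it, the latent half of the output mask is all ones; with it, the resized latent_mask is placed there instead, so only the regions it marks are treated as fillable. A minimal standalone sketch under the same shape conventions (B, H, W; the sizes are illustrative):
import torch

B, H, new_width, pad_w = 1, 512, 512, 512
latent_mask = torch.ones(B, 64, 64)  # hypothetical low-resolution mask

# Resize to the padded region, as in the diff above
resized = torch.nn.functional.interpolate(
    latent_mask.unsqueeze(1), size=(H, pad_w), mode="nearest"
).squeeze(1)

padded_mask = torch.zeros(B, H, new_width + pad_w)
padded_mask[:, :, new_width:] = resized  # masked only where latent_mask says so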
@@ -3059,3 +3074,82 @@ class ImageCropByMaskBatch:
out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)
return (out_rgb, out_masks)
class ImagePadKJ:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"pad_mode": (["edge", "color"],),
"color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
}
, "optional": {
"masks": ("MASK", ),
}
}
RETURN_TYPES = ("IMAGE", "MASK", )
RETURN_NAMES = ("images", "masks",)
FUNCTION = "pad"
CATEGORY = "KJNodes/image"
DESCRIPTION = "Crops the input images based on the provided masks."
def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None):
B, H, W, C = image.shape
# Resize masks to image dimensions if necessary
if mask is not None:
BM, HM, WM = mask.shape
if HM != H or WM != W:
mask = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
# Parse background color
bg_color = [int(x.strip())/255.0 for x in color.split(",")]
if len(bg_color) == 1:
bg_color = bg_color * 3 # Grayscale to RGB
bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device)
# Calculate padding sizes with extra padding
pad_left = left + extra_padding
pad_right = right + extra_padding
pad_top = top + extra_padding
pad_bottom = bottom + extra_padding
padded_width = W + pad_left + pad_right
padded_height = H + pad_top + pad_bottom
out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)
# Fill padded areas
for b in range(B):
if pad_mode == "edge":
# Pad with edge color
# Define edge pixels
top_edge = image[b, 0, :, :]
bottom_edge = image[b, H-1, :, :]
left_edge = image[b, :, 0, :]
right_edge = image[b, :, W-1, :]
# Fill borders with edge colors
out_image[b, :pad_top, :, :] = top_edge.mean(dim=0)
out_image[b, pad_top+H:, :, :] = bottom_edge.mean(dim=0)
out_image[b, :, :pad_left, :] = left_edge.mean(dim=0)
out_image[b, :, pad_left+W:, :] = right_edge.mean(dim=0)
out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
else:
# Pad with specified background color
out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0) # Expand for H and W dimensions
out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b]
if mask is not None:
out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device)
for m in range(BM):
out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = mask[m]
else:
out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device)
return (out_image, out_masks)
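A quick sanity check of the node outside a full ComfyUI graph (assuming torch is available and the class is importable; the sizes are illustrative):
import torch

image = torch.rand(1, 256, 256, 3)  # B, H, W, C in 0..1
node = ImagePadKJ()

# Edge mode: borders are filled with the averaged edge colors
padded, mask = node.pad(image, left=16, right=16, top=0, bottom=0,
                        extra_padding=8, color="0, 0, 0", pad_mode="edge")
print(padded.shape)  # torch.Size([1, 272, 304, 3]): H + 8 + 8, W + 24 + 24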

View File

@@ -9,7 +9,7 @@ import importlib
import model_management
import folder_paths
from nodes import MAX_RESOLUTION
from comfy.utils import common_upscale, ProgressBar
from comfy.utils import common_upscale, ProgressBar, load_torch_file
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts"))
@@ -1964,7 +1964,7 @@ class FluxBlockLoraLoader:
CATEGORY = "KJNodes/experimental"
def load_lora(self, model, strength_model, lora_name=None, opt_lora_path=None, blocks=None):
from comfy.utils import load_torch_file
import comfy.lora
if opt_lora_path:
@@ -2299,4 +2299,104 @@ class ImageNoiseAugmentation:
image_noise = torch.randn_like(image) * sigma[:, None, None, None]
image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
image_out = image + image_noise
return image_out,
return image_out,
class VAELoaderKJ:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
sd3_taesd_enc = False
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
elif v.startswith("taesd3_decoder."):
sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
if sd3_taesd_dec and sd3_taesd_enc:
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
sd["vae_shift"] = torch.tensor(0.0)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
sd["vae_shift"] = torch.tensor(0.0)
elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
return sd
@classmethod
def INPUT_TYPES(s):
return {
"required": { "vae_name": (s.vae_list(), ),
"device": (["main_device", "cpu"],),
"weight_dtype": (["bf16", "fp16", "fp32" ],),
}
}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
CATEGORY = "KJNodes/vae"
def load_vae(self, vae_name, device, weight_dtype):
from comfy.sd import VAE
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[weight_dtype]
if device == "main_device":
device = model_management.get_torch_device()
elif device == "cpu":
device = torch.device("cpu")
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = load_torch_file(vae_path)
vae = VAE(sd=sd, device=device, dtype=dtype)
return (vae,)
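A usage sketch; this only works inside a running ComfyUI session, where folder_paths knows the model directories, and the checkpoint name here is hypothetical:
loader = VAELoaderKJ()

# Keep the VAE on the CPU in fp16; "main_device" would pick the torch device instead.
# The "taesd"/"taesdxl"/"taesd3"/"taef1" names load the approximate VAEs the same way.
(vae,) = loader.load_vae("ae.safetensors", device="cpu", weight_dtype="fp16")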