mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00

Add ImagePadKJ, VAELoaderKJ

Simple pad node, and a VAE loader that lets you choose the device and dtype.

This commit is contained in:
parent f653a8e45e
commit 8950c5fe67
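
For context, the device and dtype choices in the new VAE loader reduce to a small lookup before the VAE is constructed. A minimal standalone sketch of that mapping (it mirrors the load_vae method added in the diff below; inside ComfyUI, "main_device" resolves via model_management.get_torch_device()):

    import torch

    # UI string -> torch dtype, as in VAELoaderKJ.load_vae further down
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}["fp16"]
    device = torch.device("cpu")  # or the main compute device inside ComfyUI
    print(dtype, device)  # torch.float16 cpu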
@@ -71,6 +71,7 @@ NODE_CONFIG = {
     "ImageNoiseAugmentation": {"class": ImageNoiseAugmentation, "name": "Image Noise Augmentation"},
     "ImageNormalize_Neg1_To_1": {"class": ImageNormalize_Neg1_To_1, "name": "Image Normalize -1 to 1"},
     "ImagePass": {"class": ImagePass},
+    "ImagePadKJ": {"class": ImagePadKJ, "name": "ImagePad KJ"},
     "ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"},
     "ImagePadForOutpaintTargetSize": {"class": ImagePadForOutpaintTargetSize, "name": "Image Pad For Outpaint Target Size"},
     "ImagePrepForICLora": {"class": ImagePrepForICLora, "name": "Image Prep For ICLora"},
@@ -176,6 +177,7 @@ NODE_CONFIG = {
     "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"},
     "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
     "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
+    "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},

     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
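
Entries of this shape are typically expanded into ComfyUI's NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS. A hedged sketch of that expansion (the helper name here is illustrative, not taken from this diff):

    def generate_node_mappings(node_config):
        # split {"NodeID": {"class": cls, "name": "Display Name"}} into ComfyUI's two maps
        class_mappings = {k: v["class"] for k, v in node_config.items()}
        display_names = {k: v.get("name", k) for k, v in node_config.items()}
        return class_mappings, display_names

    NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)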
@@ -1105,6 +1105,7 @@ class ImagePrepForICLora:
             },
             "optional": {
                 "latent_image": ("IMAGE",),
+                "latent_mask": ("MASK",),
                 "reference_mask": ("MASK",),
             }
         }
@@ -1114,7 +1115,7 @@ class ImagePrepForICLora:

    CATEGORY = "image"

-    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None):
+    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None, latent_mask=None):

         if reference_mask is not None:
             if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
@@ -1127,7 +1128,7 @@ class ImagePrepForICLora:
         if reference_mask is not None:
             resized_mask = torch.nn.functional.interpolate(
                 reference_mask.unsqueeze(1),
-                size=(image.shape[1], image.shape[2]),
+                size=(H, W),
                 mode='nearest'
             ).squeeze(1)
             print(resized_mask.shape)
@@ -1145,16 +1146,30 @@ class ImagePrepForICLora:
         else:
             resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
             pad_image = resized_latent_image
+        if latent_mask is not None:
+            resized_latent_mask = torch.nn.functional.interpolate(
+                latent_mask.unsqueeze(1),
+                size=(pad_image.shape[1], pad_image.shape[2]),
+                mode='nearest'
+            ).squeeze(1)

         if border_width > 0:
             border = torch.zeros((B, output_height, border_width, C), device=image.device)
             padded_image = torch.cat((resized_image, border, pad_image), dim=2)
-            padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
-            padded_mask[:, :, :new_width + border_width] = 0
+            if latent_mask is not None:
+                padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, (new_width + border_width):] = resized_latent_mask
+            else:
+                padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, :new_width + border_width] = 0
         else:
             padded_image = torch.cat((resized_image, pad_image), dim=2)
-            padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
-            padded_mask[:, :, :new_width] = 0
+            if latent_mask is not None:
+                padded_mask = torch.zeros((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, new_width:] = resized_latent_mask
+            else:
+                padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
+                padded_mask[:, :, :new_width] = 0

         return (padded_image, padded_mask)
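
A standalone sketch (hypothetical sizes, not part of the commit) of the mask layout the new latent_mask branch produces: the reference half stays zero, and the latent half now takes the user's resized mask instead of the previous all-ones fill:

    import torch

    B, H, new_width, pad_w = 1, 2, 4, 4
    latent_mask = torch.ones((B, H, pad_w))             # stand-in for a user-supplied mask
    padded_mask = torch.zeros((B, H, new_width + pad_w))
    padded_mask[:, :, new_width:] = latent_mask         # only the latent region becomes editable
    print(padded_mask[0, 0])  # tensor([0., 0., 0., 0., 1., 1., 1., 1.])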
@@ -3059,3 +3074,82 @@ class ImageCropByMaskBatch:
         out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)

         return (out_rgb, out_masks)
+
+class ImagePadKJ:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE", ),
+                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+                "extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+                "pad_mode": (["edge", "color"],),
+                "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
+            },
+            "optional": {
+                "mask": ("MASK", ),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE", "MASK", )
+    RETURN_NAMES = ("images", "masks",)
+    FUNCTION = "pad"
+    CATEGORY = "KJNodes/image"
+    DESCRIPTION = "Pads the input images by the given amounts, optionally padding a provided mask to match."
+
+    def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None):
+        B, H, W, C = image.shape
+
+        # Resize the mask to the image dimensions if necessary
+        if mask is not None:
+            BM, HM, WM = mask.shape
+            if HM != H or WM != W:
+                mask = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
+
+        # Parse background color
+        bg_color = [int(x.strip()) / 255.0 for x in color.split(",")]
+        if len(bg_color) == 1:
+            bg_color = bg_color * 3  # grayscale to RGB
+        bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device)
+
+        # Calculate padding sizes; extra_padding is applied to all four sides
+        pad_left = left + extra_padding
+        pad_right = right + extra_padding
+        pad_top = top + extra_padding
+        pad_bottom = bottom + extra_padding
+
+        padded_width = W + pad_left + pad_right
+        padded_height = H + pad_top + pad_bottom
+        out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)
+
+        # Fill padded areas
+        for b in range(B):
+            if pad_mode == "edge":
+                # Fill each border with the mean color of the corresponding edge
+                top_edge = image[b, 0, :, :]
+                bottom_edge = image[b, H - 1, :, :]
+                left_edge = image[b, :, 0, :]
+                right_edge = image[b, :, W - 1, :]
+
+                out_image[b, :pad_top, :, :] = top_edge.mean(dim=0)
+                out_image[b, pad_top + H:, :, :] = bottom_edge.mean(dim=0)
+                out_image[b, :, :pad_left, :] = left_edge.mean(dim=0)
+                out_image[b, :, pad_left + W:, :] = right_edge.mean(dim=0)
+                out_image[b, pad_top:pad_top + H, pad_left:pad_left + W, :] = image[b]
+            else:
+                # Fill with the specified background color, then paste the image
+                out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0)  # broadcast over H and W
+                out_image[b, pad_top:pad_top + H, pad_left:pad_left + W, :] = image[b]
+
+        if mask is not None:
+            out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device)
+            for m in range(BM):
+                out_masks[m, pad_top:pad_top + H, pad_left:pad_left + W] = mask[m]
+        else:
+            out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device)
+
+        return (out_image, out_masks)
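
Since extra_padding is stacked on top of the per-side values, the output geometry is padded_width = W + left + right + 2 * extra_padding, and likewise for the height. A quick standalone check with hypothetical values:

    import torch

    B, H, W, C = 1, 4, 6, 3
    left, right, top, bottom, extra = 2, 1, 0, 3, 1
    image = torch.rand(B, H, W, C)
    out = torch.zeros(B, H + top + bottom + 2 * extra, W + left + right + 2 * extra, C)
    out[:, top + extra : top + extra + H, left + extra : left + extra + W] = image
    print(out.shape)  # torch.Size([1, 9, 11, 3])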
nodes/nodes.py (106 changed lines)
@@ -9,7 +9,7 @@ import importlib
 import model_management
 import folder_paths
 from nodes import MAX_RESOLUTION
-from comfy.utils import common_upscale, ProgressBar
+from comfy.utils import common_upscale, ProgressBar, load_torch_file

 script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts"))
@@ -1964,7 +1964,7 @@ class FluxBlockLoraLoader:
     CATEGORY = "KJNodes/experimental"

     def load_lora(self, model, strength_model, lora_name=None, opt_lora_path=None, blocks=None):
-        from comfy.utils import load_torch_file
         import comfy.lora

         if opt_lora_path:
@@ -2299,4 +2299,104 @@ class ImageNoiseAugmentation:
         image_noise = torch.randn_like(image) * sigma[:, None, None, None]
         image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
         image_out = image + image_noise
-        return image_out,
+        return image_out,
+
+class VAELoaderKJ:
+    @staticmethod
+    def vae_list():
+        vaes = folder_paths.get_filename_list("vae")
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+        sdxl_taesd_enc = False
+        sdxl_taesd_dec = False
+        sd1_taesd_enc = False
+        sd1_taesd_dec = False
+        sd3_taesd_enc = False
+        sd3_taesd_dec = False
+        f1_taesd_enc = False
+        f1_taesd_dec = False
+
+        for v in approx_vaes:
+            if v.startswith("taesd_decoder."):
+                sd1_taesd_dec = True
+            elif v.startswith("taesd_encoder."):
+                sd1_taesd_enc = True
+            elif v.startswith("taesdxl_decoder."):
+                sdxl_taesd_dec = True
+            elif v.startswith("taesdxl_encoder."):
+                sdxl_taesd_enc = True
+            elif v.startswith("taesd3_decoder."):
+                sd3_taesd_dec = True
+            elif v.startswith("taesd3_encoder."):
+                sd3_taesd_enc = True
+            elif v.startswith("taef1_encoder."):
+                f1_taesd_enc = True
+            elif v.startswith("taef1_decoder."):
+                f1_taesd_dec = True
+        if sd1_taesd_dec and sd1_taesd_enc:
+            vaes.append("taesd")
+        if sdxl_taesd_dec and sdxl_taesd_enc:
+            vaes.append("taesdxl")
+        if sd3_taesd_dec and sd3_taesd_enc:
+            vaes.append("taesd3")
+        if f1_taesd_dec and f1_taesd_enc:
+            vaes.append("taef1")
+        return vaes
+
+    @staticmethod
+    def load_taesd(name):
+        sd = {}
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+
+        encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
+        decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
+
+        enc = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
+        for k in enc:
+            sd["taesd_encoder.{}".format(k)] = enc[k]
+
+        dec = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
+        for k in dec:
+            sd["taesd_decoder.{}".format(k)] = dec[k]
+
+        if name == "taesd":
+            sd["vae_scale"] = torch.tensor(0.18215)
+            sd["vae_shift"] = torch.tensor(0.0)
+        elif name == "taesdxl":
+            sd["vae_scale"] = torch.tensor(0.13025)
+            sd["vae_shift"] = torch.tensor(0.0)
+        elif name == "taesd3":
+            sd["vae_scale"] = torch.tensor(1.5305)
+            sd["vae_shift"] = torch.tensor(0.0609)
+        elif name == "taef1":
+            sd["vae_scale"] = torch.tensor(0.3611)
+            sd["vae_shift"] = torch.tensor(0.1159)
+        return sd
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "vae_name": (s.vae_list(),),
+                "device": (["main_device", "cpu"],),
+                "weight_dtype": (["bf16", "fp16", "fp32"],),
+            }
+        }
+
+    RETURN_TYPES = ("VAE",)
+    FUNCTION = "load_vae"
+    CATEGORY = "KJNodes/vae"
+
+    def load_vae(self, vae_name, device, weight_dtype):
+        from comfy.sd import VAE
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[weight_dtype]
+        if device == "main_device":
+            device = model_management.get_torch_device()
+        elif device == "cpu":
+            device = torch.device("cpu")
+        if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
+            sd = self.load_taesd(vae_name)
+        else:
+            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
+            sd = load_torch_file(vae_path)
+        vae = VAE(sd=sd, device=device, dtype=dtype)
+        return (vae,)
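
The taesd-style entries in vae_list() are gated on both halves of the model being present. A standalone sketch of that gating, with a hypothetical vae_approx directory listing:

    approx_vaes = ["taesd_encoder.pth", "taesd_decoder.pth", "taef1_decoder.safetensors"]

    def has_pair(name, files):
        # both an encoder and a decoder file must exist for the entry to be offered
        return any(f.startswith(name + "_encoder.") for f in files) and \
               any(f.startswith(name + "_decoder.") for f in files)

    print([n for n in ("taesd", "taesdxl", "taesd3", "taef1") if has_pair(n, approx_vaes)])
    # ['taesd'] (taef1 has no encoder file here, so it is not listed)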