# Reconstructed from the patch "Add Effnet encode node for stable-cascade img2img"
# (target file: nodes.py). The same patch also registers the node in the two
# module-level tables:
#   NODE_CLASS_MAPPINGS:        "EffnetEncode": EffnetEncode
#   NODE_DISPLAY_NAME_MAPPINGS: "EffnetEncode": "EffnetEncode"

import torchvision
from torch import nn
from safetensors import safe_open


class EfficientNetEncoder(nn.Module):
    """EfficientNetV2-S feature extractor followed by a 1x1 projection.

    The backbone is the ``features`` part of torchvision's
    ``efficientnet_v2_s`` (created with its default weights and put in eval
    mode). Its 1280-channel output is mapped to ``c_latent`` channels by a
    bias-free 1x1 convolution and a non-affine batch norm.

    Args:
        c_latent: Number of output (latent) channels.
    """

    def __init__(self, c_latent=16):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
        self.mapper = nn.Sequential(
            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
        )

    def forward(self, x):
        """Return the projected backbone features for the image batch ``x``."""
        return self.mapper(self.backbone(x))


class EffnetEncode:
    """ComfyUI node: encode an IMAGE batch into a LATENT with the effnet encoder.

    The encoder weights are read from
    ``<models_dir>/vae/Stable-cascade/effnet_encoder.safetensors`` and are
    downloaded from the ``stabilityai/stable-cascade`` Hugging Face repo when
    that file does not exist yet.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
            }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "KJNodes"

    def encode(self, image):
        """Run the effnet encoder on ``image`` and return it as a latent dict.

        Args:
            image: Image batch tensor. NOTE(review): assumed to be laid out as
                (batch, height, width, channels), since the last axis is moved
                to position 1 before normalization - confirm against callers.

        Returns:
            A 1-tuple holding ``{"samples": tensor}``.

        Raises:
            Exception: If the checkpoint is missing and the download fails.
        """
        device = comfy.model_management.get_torch_device()

        # Move the last axis to position 1 and normalize per channel.
        image = image.permute(0, 3, 1, 2).to(device)
        normalize = torchvision.transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        )
        image = normalize(image)

        effnet = EfficientNetEncoder()

        checkpoint_dir = os.path.join(folder_paths.models_dir, "vae", "Stable-cascade")
        effnet_checkpoint = os.path.join(checkpoint_dir, "effnet_encoder.safetensors")
        if not os.path.exists(effnet_checkpoint):
            try:
                # Imported lazily: only needed when the weights are not on disk.
                from huggingface_hub import snapshot_download
                snapshot_download(repo_id="stabilityai/stable-cascade",
                                  allow_patterns=["effnet_encoder.safetensors"],
                                  local_dir=checkpoint_dir,
                                  local_dir_use_symlinks=False)
            except Exception as err:
                # `except Exception` (not a bare except) so Ctrl-C still
                # propagates; chain the cause so the real failure is visible.
                raise Exception("Model not found and download failed. (https://huggingface.co/stabilityai/stable-cascade, effnet_encoder.safetensors)") from err

        with safe_open(effnet_checkpoint, framework="pt", device="cpu") as f:
            effnet_state_dict = {key: f.get_tensor(key) for key in f.keys()}

        # The previous code tested `'state_dict' in effnet_checkpoint`, i.e. a
        # substring test on the checkpoint *path*, and then indexed that string;
        # a models directory containing "state_dict" raised TypeError. The
        # tensors read above are loaded directly instead.
        effnet.load_state_dict(effnet_state_dict)
        effnet.eval().requires_grad_(False).to(device)
        t = effnet(image)
        return ({"samples": t},)