Add Effnet encode node for stable-cascade img2img

kijai 2024-02-18 16:18:56 +02:00
parent 450bcd417b
commit 7dc5be88aa


@@ -3575,6 +3575,67 @@ class ImageUpscaleWithModelBatched:
        return (t,)
import torchvision
from torch import nn
from safetensors import safe_open

class EfficientNetEncoder(nn.Module):
    def __init__(self, c_latent=16):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
        self.mapper = nn.Sequential(
            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
        )

    def forward(self, x):
        return self.mapper(self.backbone(x))
class EffnetEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
        }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "KJNodes"

    def encode(self, image):
        device = comfy.model_management.get_torch_device()
        # ComfyUI images are BHWC; the EfficientNet backbone expects BCHW.
        image = image.permute(0, 3, 1, 2).to(device)
        effnet_preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            )
        ])
        image = effnet_preprocess(image)

        effnet = EfficientNetEncoder()
        effnet_checkpoint = os.path.join(folder_paths.models_dir, "vae", "Stable-cascade", "effnet_encoder.safetensors")
        if not os.path.exists(effnet_checkpoint):
            try:
                from huggingface_hub import snapshot_download
                download_to = os.path.join(folder_paths.models_dir, "vae", "Stable-cascade")
                snapshot_download(repo_id="stabilityai/stable-cascade", allow_patterns=["effnet_encoder.safetensors"],
                                  local_dir=download_to, local_dir_use_symlinks=False)
            except Exception:
                raise Exception("Model not found and download failed. (https://huggingface.co/stabilityai/stable-cascade, effnet_encoder.safetensors)")

        effnet_state_dict = {}
        with safe_open(effnet_checkpoint, framework="pt", device="cpu") as f:
            for key in f.keys():
                effnet_state_dict[key] = f.get_tensor(key)

        # The safetensors file stores the weights flat, so load them directly.
        effnet.load_state_dict(effnet_state_dict)
        effnet.eval().requires_grad_(False).to(device)

        t = effnet(image)
        return ({"samples": t},)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
@@ -3640,7 +3701,8 @@ NODE_CLASS_MAPPINGS = {
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
"CondPassThrough": CondPassThrough,
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched,
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule,
"EffnetEncode": EffnetEncode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@@ -3706,5 +3768,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
"CondPassThrough": "CondPassThrough",
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched",
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule"
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule",
"EffnetEncode": "EffnetEncode"
}
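
For reference, a minimal standalone sketch of what the new node computes outside the ComfyUI graph. It assumes the EfficientNetEncoder class from the diff above is in scope and that effnet_encoder.safetensors has already been downloaded; the checkpoint path and the dummy input below are placeholders, not the node's actual values.

import torch
import torchvision
from safetensors import safe_open

# Placeholder path; inside the node this resolves to models/vae/Stable-cascade/.
checkpoint = "effnet_encoder.safetensors"

state_dict = {}
with safe_open(checkpoint, framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

effnet = EfficientNetEncoder()   # class defined in the diff above
effnet.load_state_dict(state_dict)
effnet.eval().requires_grad_(False)

# Dummy 1x3x768x768 image in [0, 1], normalized with the same ImageNet
# statistics the node applies before encoding.
image = torch.rand(1, 3, 768, 768)
normalize = torchvision.transforms.Normalize(
    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
)
with torch.no_grad():
    latent = effnet(normalize(image))

print(latent.shape)  # torch.Size([1, 16, 24, 24]): 16 channels, spatial size H/32 x W/32

This tensor is what the node returns as {"samples": t}; per the commit message, it is intended as the starting latent for Stable Cascade Stage C img2img sampling.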