mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 23:04:39 +08:00
Add Effnet encode node for stable-cascade img2img
This commit is contained in:
parent
450bcd417b
commit
7dc5be88aa
67
nodes.py
67
nodes.py
@ -3575,6 +3575,67 @@ class ImageUpscaleWithModelBatched:
|
||||
|
||||
return (t,)
|
||||
|
||||
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from safetensors import safe_open
|
||||
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
def __init__(self, c_latent=16):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mapper(self.backbone(x))
|
||||
|
||||
class EffnetEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "KJNodes"
|
||||
|
||||
def encode(self, image):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
image = image.permute(0, 3, 1, 2).to(device)
|
||||
effnet_preprocess = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Normalize(
|
||||
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
|
||||
)
|
||||
])
|
||||
image = effnet_preprocess(image)
|
||||
effnet = EfficientNetEncoder()
|
||||
|
||||
effnet_checkpoint = os.path.join(folder_paths.models_dir,"vae", "Stable-cascade","effnet_encoder.safetensors")
|
||||
if not os.path.exists(effnet_checkpoint):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
download_to = os.path.join(folder_paths.models_dir,'vae', "Stable-cascade")
|
||||
snapshot_download(repo_id="stabilityai/stable-cascade", allow_patterns=["effnet_encoder.safetensors"],
|
||||
local_dir=download_to, local_dir_use_symlinks=False)
|
||||
except:
|
||||
raise Exception("Model not found and download failed. (https://huggingface.co/stabilityai/stable-cascade, effnet_encoder.safetensors)")
|
||||
|
||||
effnet_state_dict = {}
|
||||
with safe_open(effnet_checkpoint, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
effnet_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
effnet.load_state_dict(effnet_state_dict if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
|
||||
effnet.eval().requires_grad_(False).to(device)
|
||||
t = effnet(image)
|
||||
return ({"samples":t}, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"INTConstant": INTConstant,
|
||||
"FloatConstant": FloatConstant,
|
||||
@ -3640,7 +3701,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
|
||||
"CondPassThrough": CondPassThrough,
|
||||
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched,
|
||||
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule
|
||||
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule,
|
||||
"EffnetEncode": EffnetEncode
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"INTConstant": "INT Constant",
|
||||
@ -3706,5 +3768,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
|
||||
"CondPassThrough": "CondPassThrough",
|
||||
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched",
|
||||
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule"
|
||||
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule",
|
||||
"EffnetEncode": "EffnetEncode"
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user