mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-13 16:15:45 +08:00
Add Effnet encode node for stable-cascade img2img
This commit is contained in:
parent
450bcd417b
commit
7dc5be88aa
67
nodes.py
67
nodes.py
@ -3575,6 +3575,67 @@ class ImageUpscaleWithModelBatched:
|
|||||||
|
|
||||||
return (t,)
|
return (t,)
|
||||||
|
|
||||||
|
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
class EfficientNetEncoder(nn.Module):
|
||||||
|
def __init__(self, c_latent=16):
|
||||||
|
super().__init__()
|
||||||
|
self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
|
||||||
|
self.mapper = nn.Sequential(
|
||||||
|
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||||
|
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.mapper(self.backbone(x))
|
||||||
|
|
||||||
|
class EffnetEncode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
|
def encode(self, image):
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
image = image.permute(0, 3, 1, 2).to(device)
|
||||||
|
effnet_preprocess = torchvision.transforms.Compose([
|
||||||
|
torchvision.transforms.Normalize(
|
||||||
|
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
|
||||||
|
)
|
||||||
|
])
|
||||||
|
image = effnet_preprocess(image)
|
||||||
|
effnet = EfficientNetEncoder()
|
||||||
|
|
||||||
|
effnet_checkpoint = os.path.join(folder_paths.models_dir,"vae", "Stable-cascade","effnet_encoder.safetensors")
|
||||||
|
if not os.path.exists(effnet_checkpoint):
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
download_to = os.path.join(folder_paths.models_dir,'vae', "Stable-cascade")
|
||||||
|
snapshot_download(repo_id="stabilityai/stable-cascade", allow_patterns=["effnet_encoder.safetensors"],
|
||||||
|
local_dir=download_to, local_dir_use_symlinks=False)
|
||||||
|
except:
|
||||||
|
raise Exception("Model not found and download failed. (https://huggingface.co/stabilityai/stable-cascade, effnet_encoder.safetensors)")
|
||||||
|
|
||||||
|
effnet_state_dict = {}
|
||||||
|
with safe_open(effnet_checkpoint, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
effnet_state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
|
effnet.load_state_dict(effnet_state_dict if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
|
||||||
|
effnet.eval().requires_grad_(False).to(device)
|
||||||
|
t = effnet(image)
|
||||||
|
return ({"samples":t}, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
"FloatConstant": FloatConstant,
|
"FloatConstant": FloatConstant,
|
||||||
@ -3640,7 +3701,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
|
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
|
||||||
"CondPassThrough": CondPassThrough,
|
"CondPassThrough": CondPassThrough,
|
||||||
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched,
|
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched,
|
||||||
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule
|
"ScaleBatchPromptSchedule": ScaleBatchPromptSchedule,
|
||||||
|
"EffnetEncode": EffnetEncode
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -3706,5 +3768,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
|
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
|
||||||
"CondPassThrough": "CondPassThrough",
|
"CondPassThrough": "CondPassThrough",
|
||||||
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched",
|
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched",
|
||||||
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule"
|
"ScaleBatchPromptSchedule": "ScaleBatchPromptSchedule",
|
||||||
|
"EffnetEncode": "EffnetEncode"
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user