mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-10 06:24:26 +08:00
convert nodes_hypernetwork.py to V3 schema (#10583)
This commit is contained in:
parent
88df172790
commit
1f3f7a2823
@ -2,6 +2,9 @@ import comfy.utils
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetwork_patch(path, strength):
|
def load_hypernetwork_patch(path, strength):
|
||||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||||
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
|
|
||||||
return hypernetwork_patch(out, strength)
|
return hypernetwork_patch(out, strength)
|
||||||
|
|
||||||
class HypernetworkLoader:
|
class HypernetworkLoader(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return IO.Schema(
|
||||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
node_id="HypernetworkLoader",
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
category="loaders",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
IO.Model.Input("model"),
|
||||||
FUNCTION = "load_hypernetwork"
|
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||||
|
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
@classmethod
|
||||||
|
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
|
||||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
model_hypernetwork.set_model_attn1_patch(patch)
|
model_hypernetwork.set_model_attn1_patch(patch)
|
||||||
model_hypernetwork.set_model_attn2_patch(patch)
|
model_hypernetwork.set_model_attn2_patch(patch)
|
||||||
return (model_hypernetwork,)
|
return IO.NodeOutput(model_hypernetwork)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
load_hypernetwork = execute # TODO: remove
|
||||||
"HypernetworkLoader": HypernetworkLoader
|
|
||||||
}
|
|
||||||
|
class HyperNetworkExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
HypernetworkLoader,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||||
|
return HyperNetworkExtension()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user