mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
convert nodes_lora_extract.py to V3 schema (#10182)
This commit is contained in:
parent
2ba8d7cce8
commit
989f715d92
@ -5,6 +5,8 @@ import folder_paths
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
|||||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
return output_sd
|
return output_sd
|
||||||
|
|
||||||
class LoraSave:
|
class LoraSave(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LoraSave",
|
||||||
|
display_name="Extract and Save Lora",
|
||||||
|
category="_for_testing",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||||
|
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
|
||||||
|
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
|
||||||
|
io.Boolean.Input("bias_diff", default=True),
|
||||||
|
io.Model.Input(
|
||||||
|
"model_diff",
|
||||||
|
tooltip="The ModelSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
io.Clip.Input(
|
||||||
|
"text_encoder_diff",
|
||||||
|
tooltip="The CLIPSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
|
||||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
|
||||||
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
|
||||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
|
||||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
|
||||||
},
|
|
||||||
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
|
||||||
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
|
||||||
|
|
||||||
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
|
||||||
if model_diff is None and text_encoder_diff is None:
|
if model_diff is None and text_encoder_diff is None:
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
lora_type = LORA_TYPES.get(lora_type)
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||||
|
|
||||||
output_sd = {}
|
output_sd = {}
|
||||||
if model_diff is not None:
|
if model_diff is not None:
|
||||||
@ -108,12 +118,16 @@ class LoraSave:
|
|||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"LoraSave": LoraSave
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class LoraSaveExtension(ComfyExtension):
|
||||||
"LoraSave": "Extract and Save Lora"
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LoraSave,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LoraSaveExtension:
|
||||||
|
return LoraSaveExtension()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user