Add ModelSaveKJ

node to save a model with the prefix you want
This commit is contained in:
kijai 2024-09-24 23:27:54 +03:00
parent 6d119fda33
commit c31fa9f438
2 changed files with 51 additions and 1 deletions

View File

@ -105,6 +105,7 @@ NODE_CONFIG = {
"EmptyLatentImagePresets": {"class": EmptyLatentImagePresets, "name": "Empty Latent Image Presets"},
"EmptyLatentImageCustomPresets": {"class": EmptyLatentImageCustomPresets, "name": "Empty Latent Image Custom Presets"},
"ModelPassThrough": {"class": ModelPassThrough, "name": "ModelPass"},
"ModelSaveKJ": {"class": ModelSaveKJ, "name": "Model Save KJ"},
"SetShakkerLabsUnionControlNetType": {"class": SetShakkerLabsUnionControlNetType, "name": "Set Shakker Labs Union ControlNet Type"},
#audioscheduler stuff
"NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask},

View File

@ -2070,4 +2070,53 @@ class SetShakkerLabsUnionControlNetType:
else:
control_net.set_extra_arg("control_type", [])
return (control_net,)
return (control_net,)
class ModelSaveKJ:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
"model_key_prefix": ("STRING", {"default": "model.diffusion_model."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, filename_prefix, model_key_prefix, prompt=None, extra_pnginfo=None):
from comfy.utils import save_torch_file
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
load_models = [model]
model_management.load_models_gpu(load_models, force_patch_weights=True)
default_prefix = "model.diffusion_model."
sd = model.model.state_dict_for_saving(None, None, None)
new_sd = {}
for k in sd:
if k.startswith(default_prefix):
new_key = model_key_prefix + k[len(default_prefix):]
else:
new_key = k # In case the key doesn't start with the default prefix, keep it unchanged
t = sd[k]
if not t.is_contiguous():
t = t.contiguous()
new_sd[new_key] = t
print(full_output_folder)
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder)
save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
return {}