From c31fa9f4385789500d066cc53413851ee294fdd1 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Sep 2024 23:27:54 +0300 Subject: [PATCH] Add ModelSaveKJ node to save a model with the prefix you want --- __init__.py | 1 + nodes/nodes.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 74fccf0..41ed9c5 100644 --- a/__init__.py +++ b/__init__.py @@ -105,6 +105,7 @@ NODE_CONFIG = { "EmptyLatentImagePresets": {"class": EmptyLatentImagePresets, "name": "Empty Latent Image Presets"}, "EmptyLatentImageCustomPresets": {"class": EmptyLatentImageCustomPresets, "name": "Empty Latent Image Custom Presets"}, "ModelPassThrough": {"class": ModelPassThrough, "name": "ModelPass"}, + "ModelSaveKJ": {"class": ModelSaveKJ, "name": "Model Save KJ"}, "SetShakkerLabsUnionControlNetType": {"class": SetShakkerLabsUnionControlNetType, "name": "Set Shakker Labs Union ControlNet Type"}, #audioscheduler stuff "NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask}, diff --git a/nodes/nodes.py b/nodes/nodes.py index e7a92f3..300d848 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2070,4 +2070,53 @@ class SetShakkerLabsUnionControlNetType: else: control_net.set_extra_arg("control_type", []) - return (control_net,) \ No newline at end of file + return (control_net,) + +class ModelSaveKJ: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}), + "model_key_prefix": ("STRING", {"default": "model.diffusion_model."}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, model, filename_prefix, model_key_prefix, prompt=None, extra_pnginfo=None): + from comfy.utils import save_torch_file + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + load_models = [model] + + model_management.load_models_gpu(load_models, force_patch_weights=True) + default_prefix = "model.diffusion_model." + + sd = model.model.state_dict_for_saving(None, None, None) + + new_sd = {} + for k in sd: + if k.startswith(default_prefix): + new_key = model_key_prefix + k[len(default_prefix):] + else: + new_key = k # In case the key doesn't start with the default prefix, keep it unchanged + t = sd[k] + if not t.is_contiguous(): + t = t.contiguous() + new_sd[new_key] = t + print(full_output_folder) + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint)) + return {} + + \ No newline at end of file