mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-04-09 16:06:58 +08:00
Add ModelSaveKJ
node to save a model with the prefix you want
This commit is contained in:
parent
6d119fda33
commit
c31fa9f438
@ -105,6 +105,7 @@ NODE_CONFIG = {
|
||||
"EmptyLatentImagePresets": {"class": EmptyLatentImagePresets, "name": "Empty Latent Image Presets"},
|
||||
"EmptyLatentImageCustomPresets": {"class": EmptyLatentImageCustomPresets, "name": "Empty Latent Image Custom Presets"},
|
||||
"ModelPassThrough": {"class": ModelPassThrough, "name": "ModelPass"},
|
||||
"ModelSaveKJ": {"class": ModelSaveKJ, "name": "Model Save KJ"},
|
||||
"SetShakkerLabsUnionControlNetType": {"class": SetShakkerLabsUnionControlNetType, "name": "Set Shakker Labs Union ControlNet Type"},
|
||||
#audioscheduler stuff
|
||||
"NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask},
|
||||
|
||||
@ -2070,4 +2070,53 @@ class SetShakkerLabsUnionControlNetType:
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
return (control_net,)
|
||||
return (control_net,)
|
||||
|
||||
class ModelSaveKJ:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
|
||||
"model_key_prefix": ("STRING", {"default": "model.diffusion_model."}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, model, filename_prefix, model_key_prefix, prompt=None, extra_pnginfo=None):
|
||||
from comfy.utils import save_torch_file
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
load_models = [model]
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
default_prefix = "model.diffusion_model."
|
||||
|
||||
sd = model.model.state_dict_for_saving(None, None, None)
|
||||
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
if k.startswith(default_prefix):
|
||||
new_key = model_key_prefix + k[len(default_prefix):]
|
||||
else:
|
||||
new_key = k # In case the key doesn't start with the default prefix, keep it unchanged
|
||||
t = sd[k]
|
||||
if not t.is_contiguous():
|
||||
t = t.contiguous()
|
||||
new_sd[new_key] = t
|
||||
print(full_output_folder)
|
||||
if not os.path.exists(full_output_folder):
|
||||
os.makedirs(full_output_folder)
|
||||
save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
|
||||
return {}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user