From a7ea300a5736f6a9fd027b167633cfa08a6f869e Mon Sep 17 00:00:00 2001
From: Kijai <40791699+kijai@users.noreply.github.com>
Date: Thu, 14 Mar 2024 21:24:45 +0200
Subject: [PATCH] Proper resadapter patcher

---
 nodes.py | 36 +++++++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/nodes.py b/nodes.py
index 59f0c46..7247304 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1144,17 +1144,21 @@ class VRAM_Debug:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "image_passthrough": ("IMAGE",),
                 "empty_cache": ("BOOLEAN", {"default": True}),
                 "unload_all_models": ("BOOLEAN", {"default": False}),
-            },
+            },
+            "optional":{
+                "image_passthrough": ("IMAGE",),
+                "model_passthrough": ("MODEL",),
             }
-    RETURN_TYPES = ("IMAGE", "INT", "INT",)
+        }
+
+    RETURN_TYPES = ("IMAGE", "MODEL","INT", "INT",)
     RETURN_NAMES = ("image_passthrough", "freemem_before", "freemem_after")
     FUNCTION = "VRAMdebug"
     CATEGORY = "KJNodes"
 
-    def VRAMdebug(self, image_passthrough, empty_cache, unload_all_models):
+    def VRAMdebug(self, empty_cache, unload_all_models,image_passthrough=None, model_passthrough=None):
         freemem_before = comfy.model_management.get_free_memory()
         print("VRAMdebug: free memory before: ", freemem_before)
         if empty_cache:
@@ -1164,7 +1168,7 @@ class VRAM_Debug:
         freemem_after = comfy.model_management.get_free_memory()
         print("VRAMdebug: free memory after: ", freemem_after)
         print("VRAMdebug: freed memory: ", freemem_after - freemem_before)
-        return (image_passthrough, freemem_before, freemem_after)
+        return (image_passthrough, model_passthrough, freemem_before, freemem_after)
 
 class AnyType(str):
     """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
@@ -3812,12 +3816,26 @@ class LoadResAdapterNormalization:
 
     def load_res_adapter(self, model, resadapter_path):
         resadapter_path = f"{folder_paths.get_folder_paths('checkpoints')[0]}/{resadapter_path}"
-        
+
         if os.path.exists(resadapter_path):
-            norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
+            prefix_to_remove = 'diffusion_model.'
             model_clone = model.clone()
-            model_clone.model.load_state_dict(norm_state_dict, strict=False)
-            return (model_clone, )
+            norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
+            new_values = {key[len(prefix_to_remove):]: value for key, value in norm_state_dict.items() if key.startswith(prefix_to_remove)}
+
+            # Replace the values for the keys in the model's state dict
+            for key in model.model.diffusion_model.state_dict().keys():
+                if key in new_values:
+                    original_tensor = model.model.diffusion_model.state_dict()[key]
+                    new_tensor = new_values[key].to(model.model.diffusion_model.dtype)
+                    if original_tensor.shape == new_tensor.shape:
+                        model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
+                        print(f"Replaced key: {key}")
+                    else:
+                        print(f"Shape mismatch for key: {key}. Did not replace.")
+
+
+            return (model_clone, )
 
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
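
The core of the LoadResAdapterNormalization change is moving from model_clone.model.load_state_dict(...), which writes into the shared UNet weights in place, to ModelPatcher.add_object_patch(), which records the replacement on the clone and only swaps it in when that clone is actually patched/loaded, then restores the original on unpatch. Below is a minimal sketch of that pattern outside the patch, assuming a ComfyUI environment where `model` is a comfy.model_patcher.ModelPatcher; `key` and `new_tensor` are hypothetical stand-ins for a normalization weight name and its replacement tensor.

# Minimal sketch, not part of the patch: clone-and-patch a single UNet tensor.
# "key" and "new_tensor" are hypothetical placeholders for illustration.
def patch_norm_weight(model, key, new_tensor):
    # Work on a clone so the original ModelPatcher (and other clones sharing
    # the same underlying model) keep their own weights.
    model_clone = model.clone()
    # add_object_patch(name, obj) records an attribute-path override on the
    # clone; ComfyUI applies it when the model is patched for sampling and
    # reverts it when the model is unpatched.
    model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
    return model_clone

# By contrast, the pre-patch code mutated the shared UNet in place for every
# reference to it:
#   model_clone.model.load_state_dict(norm_state_dict, strict=False)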