Proper resadapter patcher

This commit is contained in:
Kijai 2024-03-14 21:24:45 +02:00
parent 9a6aaa6518
commit a7ea300a57

View File

@ -1144,17 +1144,21 @@ class VRAM_Debug:
def INPUT_TYPES(s):
return {
"required": {
"image_passthrough": ("IMAGE",),
"empty_cache": ("BOOLEAN", {"default": True}),
"unload_all_models": ("BOOLEAN", {"default": False}),
},
},
"optional":{
"image_passthrough": ("IMAGE",),
"model_passthrough": ("MODEL",),
}
RETURN_TYPES = ("IMAGE", "INT", "INT",)
}
RETURN_TYPES = ("IMAGE", "MODEL","INT", "INT",)
RETURN_NAMES = ("image_passthrough", "freemem_before", "freemem_after")
FUNCTION = "VRAMdebug"
CATEGORY = "KJNodes"
def VRAMdebug(self, image_passthrough, empty_cache, unload_all_models):
def VRAMdebug(self, empty_cache, unload_all_models,image_passthrough=None, model_passthrough=None):
freemem_before = comfy.model_management.get_free_memory()
print("VRAMdebug: free memory before: ", freemem_before)
if empty_cache:
@ -1164,7 +1168,7 @@ class VRAM_Debug:
freemem_after = comfy.model_management.get_free_memory()
print("VRAMdebug: free memory after: ", freemem_after)
print("VRAMdebug: freed memory: ", freemem_after - freemem_before)
return (image_passthrough, freemem_before, freemem_after)
return (image_passthrough, model_passthrough, freemem_before, freemem_after)
class AnyType(str):
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
@ -3812,12 +3816,26 @@ class LoadResAdapterNormalization:
def load_res_adapter(self, model, resadapter_path):
resadapter_path = f"{folder_paths.get_folder_paths('checkpoints')[0]}/{resadapter_path}"
if os.path.exists(resadapter_path):
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
prefix_to_remove = 'diffusion_model.'
model_clone = model.clone()
model_clone.model.load_state_dict(norm_state_dict, strict=False)
return (model_clone, )
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
new_values = {key[len(prefix_to_remove):]: value for key, value in norm_state_dict.items() if key.startswith(prefix_to_remove)}
# Replace the values for the keys in the model's state dict
for key in model.model.diffusion_model.state_dict().keys():
if key in new_values:
original_tensor = model.model.diffusion_model.state_dict()[key]
new_tensor = new_values[key].to(model.model.diffusion_model.dtype)
if original_tensor.shape == new_tensor.shape:
model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
print(f"Replaced key: {key}")
else:
print(f"Shape mismatch for key: {key}. Did not replace.")
return (model_clone, )
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,