mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-27 22:21:35 +08:00
Proper resadapter patcher
This commit is contained in:
parent
9a6aaa6518
commit
a7ea300a57
36
nodes.py
36
nodes.py
@ -1144,17 +1144,21 @@ class VRAM_Debug:
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image_passthrough": ("IMAGE",),
|
||||
"empty_cache": ("BOOLEAN", {"default": True}),
|
||||
"unload_all_models": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
},
|
||||
"optional":{
|
||||
"image_passthrough": ("IMAGE",),
|
||||
"model_passthrough": ("MODEL",),
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE", "INT", "INT",)
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MODEL","INT", "INT",)
|
||||
RETURN_NAMES = ("image_passthrough", "freemem_before", "freemem_after")
|
||||
FUNCTION = "VRAMdebug"
|
||||
CATEGORY = "KJNodes"
|
||||
|
||||
def VRAMdebug(self, image_passthrough, empty_cache, unload_all_models):
|
||||
def VRAMdebug(self, empty_cache, unload_all_models,image_passthrough=None, model_passthrough=None):
|
||||
freemem_before = comfy.model_management.get_free_memory()
|
||||
print("VRAMdebug: free memory before: ", freemem_before)
|
||||
if empty_cache:
|
||||
@ -1164,7 +1168,7 @@ class VRAM_Debug:
|
||||
freemem_after = comfy.model_management.get_free_memory()
|
||||
print("VRAMdebug: free memory after: ", freemem_after)
|
||||
print("VRAMdebug: freed memory: ", freemem_after - freemem_before)
|
||||
return (image_passthrough, freemem_before, freemem_after)
|
||||
return (image_passthrough, model_passthrough, freemem_before, freemem_after)
|
||||
|
||||
class AnyType(str):
|
||||
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
|
||||
@ -3812,12 +3816,26 @@ class LoadResAdapterNormalization:
|
||||
|
||||
def load_res_adapter(self, model, resadapter_path):
|
||||
resadapter_path = f"{folder_paths.get_folder_paths('checkpoints')[0]}/{resadapter_path}"
|
||||
|
||||
|
||||
if os.path.exists(resadapter_path):
|
||||
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
|
||||
prefix_to_remove = 'diffusion_model.'
|
||||
model_clone = model.clone()
|
||||
model_clone.model.load_state_dict(norm_state_dict, strict=False)
|
||||
return (model_clone, )
|
||||
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
|
||||
new_values = {key[len(prefix_to_remove):]: value for key, value in norm_state_dict.items() if key.startswith(prefix_to_remove)}
|
||||
|
||||
# Replace the values for the keys in the model's state dict
|
||||
for key in model.model.diffusion_model.state_dict().keys():
|
||||
if key in new_values:
|
||||
original_tensor = model.model.diffusion_model.state_dict()[key]
|
||||
new_tensor = new_values[key].to(model.model.diffusion_model.dtype)
|
||||
if original_tensor.shape == new_tensor.shape:
|
||||
model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
|
||||
print(f"Replaced key: {key}")
|
||||
else:
|
||||
print(f"Shape mismatch for key: {key}. Did not replace.")
|
||||
|
||||
|
||||
return (model_clone, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"INTConstant": INTConstant,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user