mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-30 00:27:16 +08:00
Proper resadapter patcher
This commit is contained in:
parent
9a6aaa6518
commit
a7ea300a57
36
nodes.py
36
nodes.py
@ -1144,17 +1144,21 @@ class VRAM_Debug:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"image_passthrough": ("IMAGE",),
|
|
||||||
"empty_cache": ("BOOLEAN", {"default": True}),
|
"empty_cache": ("BOOLEAN", {"default": True}),
|
||||||
"unload_all_models": ("BOOLEAN", {"default": False}),
|
"unload_all_models": ("BOOLEAN", {"default": False}),
|
||||||
},
|
},
|
||||||
|
"optional":{
|
||||||
|
"image_passthrough": ("IMAGE",),
|
||||||
|
"model_passthrough": ("MODEL",),
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("IMAGE", "INT", "INT",)
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "MODEL","INT", "INT",)
|
||||||
RETURN_NAMES = ("image_passthrough", "freemem_before", "freemem_after")
|
RETURN_NAMES = ("image_passthrough", "freemem_before", "freemem_after")
|
||||||
FUNCTION = "VRAMdebug"
|
FUNCTION = "VRAMdebug"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
def VRAMdebug(self, image_passthrough, empty_cache, unload_all_models):
|
def VRAMdebug(self, empty_cache, unload_all_models,image_passthrough=None, model_passthrough=None):
|
||||||
freemem_before = comfy.model_management.get_free_memory()
|
freemem_before = comfy.model_management.get_free_memory()
|
||||||
print("VRAMdebug: free memory before: ", freemem_before)
|
print("VRAMdebug: free memory before: ", freemem_before)
|
||||||
if empty_cache:
|
if empty_cache:
|
||||||
@ -1164,7 +1168,7 @@ class VRAM_Debug:
|
|||||||
freemem_after = comfy.model_management.get_free_memory()
|
freemem_after = comfy.model_management.get_free_memory()
|
||||||
print("VRAMdebug: free memory after: ", freemem_after)
|
print("VRAMdebug: free memory after: ", freemem_after)
|
||||||
print("VRAMdebug: freed memory: ", freemem_after - freemem_before)
|
print("VRAMdebug: freed memory: ", freemem_after - freemem_before)
|
||||||
return (image_passthrough, freemem_before, freemem_after)
|
return (image_passthrough, model_passthrough, freemem_before, freemem_after)
|
||||||
|
|
||||||
class AnyType(str):
|
class AnyType(str):
|
||||||
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
|
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
|
||||||
@ -3812,12 +3816,26 @@ class LoadResAdapterNormalization:
|
|||||||
|
|
||||||
def load_res_adapter(self, model, resadapter_path):
|
def load_res_adapter(self, model, resadapter_path):
|
||||||
resadapter_path = f"{folder_paths.get_folder_paths('checkpoints')[0]}/{resadapter_path}"
|
resadapter_path = f"{folder_paths.get_folder_paths('checkpoints')[0]}/{resadapter_path}"
|
||||||
|
|
||||||
if os.path.exists(resadapter_path):
|
if os.path.exists(resadapter_path):
|
||||||
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
|
prefix_to_remove = 'diffusion_model.'
|
||||||
model_clone = model.clone()
|
model_clone = model.clone()
|
||||||
model_clone.model.load_state_dict(norm_state_dict, strict=False)
|
norm_state_dict = comfy.utils.load_torch_file(resadapter_path)
|
||||||
return (model_clone, )
|
new_values = {key[len(prefix_to_remove):]: value for key, value in norm_state_dict.items() if key.startswith(prefix_to_remove)}
|
||||||
|
|
||||||
|
# Replace the values for the keys in the model's state dict
|
||||||
|
for key in model.model.diffusion_model.state_dict().keys():
|
||||||
|
if key in new_values:
|
||||||
|
original_tensor = model.model.diffusion_model.state_dict()[key]
|
||||||
|
new_tensor = new_values[key].to(model.model.diffusion_model.dtype)
|
||||||
|
if original_tensor.shape == new_tensor.shape:
|
||||||
|
model_clone.add_object_patch(f"diffusion_model.{key}.data", new_tensor)
|
||||||
|
print(f"Replaced key: {key}")
|
||||||
|
else:
|
||||||
|
print(f"Shape mismatch for key: {key}. Did not replace.")
|
||||||
|
|
||||||
|
|
||||||
|
return (model_clone, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user