* execution: Roll the UI cache into the outputs

  Currently the UI cache is parallel to the output cache, with the expectation that it is a content superset of the output cache. At the same time the UI and output caches are maintained completely separately, making it awkward to free output cache content without changing the behaviour of the UI cache.

  There are two actual users (getters) of the UI cache. The first is a direct content hit on the output cache when executing a node; this case is handled naturally by merging the UI and output caches. The second is the history JSON generation at the end of the prompt, which currently asks the cache for all_node_ids and then pulls the cache contents for those nodes (all_node_ids being the nodes of the dynamic prompt).

  So fold the UI cache into the output cache. The UI cache setter now writes to a prompt-scope dict; when the output cache is set, that value is read from the dict and stored as a tuple together with the outputs, and history generation simply iterates the prompt-scope dict. This prepares support for more complex caching strategies (like RAM pressure caching) where less than one workflow will be cached and it is desirable to keep the UI cache and output cache in sync.

* sd: Implement RAM getter for VAE
* model_patcher: Implement RAM getter for ModelPatcher
* sd: Implement RAM getter for CLIP

* Implement RAM Pressure cache

  Implement a cache sensitive to RAM pressure: when RAM headroom drops below a certain threshold, evict RAM-expensive nodes from the cache. Models and tensors are measured directly for RAM usage, and an OOM score is then computed from the RAM usage of each node. Note that, due to indirection through shared objects (like a model patcher), multiple nodes can account for the same RAM in their individual usage. The intent is that this frees chains of nodes, particularly model loaders and their associated LoRAs, since they score similarly and sort close to each other. This biases eviction towards unloading model nodes mid-flow while keeping results like text encodings and VAE outputs (see the sketch after this list).

* execution: Convert the cache entry to NamedTuple

  As commented in review: convert the cache entry to a named tuple and abstract the tuple type away from graph.py entirely.
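The RAM-pressure idea above can be sketched roughly as follows. This is an illustrative sketch only, not the actual caching implementation in execution.py: the class name RAMPressureCache, the headroom threshold, and the cache layout are assumptions. It relies only on the get_ram_usage() getters added to ModelPatcher, CLIP and VAE, direct tensor measurement, and psutil for the available-RAM check.

# Hypothetical sketch of RAM-pressure eviction (names and threshold are illustrative).
import psutil
import torch

MIN_FREE_RAM = 4 * 1024 ** 3  # assumed headroom threshold: 4 GiB

def ram_usage(value):
    # Model-like objects expose get_ram_usage() (see ModelPatcher.get_ram_usage
    # in the file below); tensors are measured directly; anything else counts as free.
    if hasattr(value, "get_ram_usage"):
        return value.get_ram_usage()
    if isinstance(value, torch.Tensor):
        return value.nelement() * value.element_size()
    return 0

class RAMPressureCache:
    """Keeps node results until RAM headroom drops, then evicts the heaviest."""

    def __init__(self):
        self._entries = {}  # node_id -> (outputs, ui), mirroring the merged cache entry

    def set(self, node_id, outputs, ui=None):
        self._entries[node_id] = (outputs, ui)

    def get(self, node_id):
        entry = self._entries.get(node_id)
        return None if entry is None else entry[0]

    def _oom_score(self, node_id):
        # Nodes holding more RAM score higher. Nodes that share one ModelPatcher
        # each report its full size, so loader/LoRA chains score similarly and
        # tend to be evicted together.
        outputs, _ = self._entries[node_id]
        return sum(ram_usage(v) for v in outputs)

    def evict_for_headroom(self):
        # Drop the most RAM-expensive entries until headroom is restored.
        while self._entries and psutil.virtual_memory().available < MIN_FREE_RAM:
            victim = max(self._entries, key=self._oom_score)
            del self._entries[victim]

The real implementation's threshold and scoring will differ; the point is that eviction is driven by measured RAM usage per node rather than by workflow-level caching.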
1300 lines · 54 KiB · Python
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Comfy

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations

import collections
import copy
import inspect
import logging
import math
import uuid
from typing import Callable, Optional

import torch

import comfy.float
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF

def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    to = model_options["transformer_options"].copy()

    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if name not in to["patches_replace"]:
        to["patches_replace"][name] = {}
    else:
        to["patches_replace"][name] = to["patches_replace"][name].copy()

    if transformer_index is not None:
        block = (block_name, number, transformer_index)
    else:
        block = (block_name, number)
    to["patches_replace"][name][block] = patch
    model_options["transformer_options"] = to
    return model_options

def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options

def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options

def create_model_options_clone(orig_model_options: dict):
    return comfy.patcher_extension.copy_nested_dicts(orig_model_options)

def create_hook_patches_clone(orig_hook_patches):
    new_hook_patches = {}
    for hook_ref in orig_hook_patches:
        new_hook_patches[hook_ref] = {}
        for k in orig_hook_patches[hook_ref]:
            new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
    return new_hook_patches

def wipe_lowvram_weight(m):
    if hasattr(m, "prev_comfy_cast_weights"):
        m.comfy_cast_weights = m.prev_comfy_cast_weights
        del m.prev_comfy_cast_weights

    if hasattr(m, "weight_function"):
        m.weight_function = []

    if hasattr(m, "bias_function"):
        m.bias_function = []

def move_weight_functions(m, device):
    if device is None:
        return 0

    memory = 0
    if hasattr(m, "weight_function"):
        for f in m.weight_function:
            if hasattr(f, "move_to"):
                memory += f.move_to(device=device)

    if hasattr(m, "bias_function"):
        for f in m.bias_function:
            if hasattr(f, "move_to"):
                memory += f.move_to(device=device)
    return memory

class LowVramPatch:
    def __init__(self, key, patches, convert_func=None, set_func=None):
        self.key = key
        self.patches = patches
        self.convert_func = convert_func
        self.set_func = set_func

    def __call__(self, weight):
        intermediate_dtype = weight.dtype
        if self.convert_func is not None:
            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)

        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
            intermediate_dtype = torch.float32
            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
            if self.set_func is None:
                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
            else:
                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)

        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
        if self.set_func is not None:
            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
        else:
            return out

def get_key_weight(model, key):
    set_func = None
    convert_func = None
    op_keys = key.rsplit('.', 1)
    if len(op_keys) < 2:
        weight = comfy.utils.get_attr(model, key)
    else:
        op = comfy.utils.get_attr(model, op_keys[0])
        try:
            set_func = getattr(op, "set_{}".format(op_keys[1]))
        except AttributeError:
            pass

        try:
            convert_func = getattr(op, "convert_{}".format(op_keys[1]))
        except AttributeError:
            pass

        weight = getattr(op, op_keys[1])
        if convert_func is not None:
            weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func

class AutoPatcherEjector:
    def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
        self.model = model
        self.was_injected = False
        self.prev_skip_injection = False
        self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only

    def __enter__(self):
        self.was_injected = False
        self.prev_skip_injection = self.model.skip_injection
        if self.skip_and_inject_on_exit_only:
            self.model.skip_injection = True
        if self.model.is_injected:
            self.model.eject_model()
            self.was_injected = True

    def __exit__(self, *args):
        if self.skip_and_inject_on_exit_only:
            self.model.skip_injection = self.prev_skip_injection
            self.model.inject_model()
        if self.was_injected and not self.model.skip_injection:
            self.model.inject_model()
        self.model.skip_injection = self.prev_skip_injection

class MemoryCounter:
    def __init__(self, initial: int, minimum=0):
        self.value = initial
        self.minimum = minimum
        # TODO: add a safe limit besides 0

    def use(self, weight: torch.Tensor):
        weight_size = weight.nelement() * weight.element_size()
        if self.is_useable(weight_size):
            self.decrement(weight_size)
            return True
        return False

    def is_useable(self, used: int):
        return self.value - used > self.minimum

    def decrement(self, used: int):
        self.value -= used

class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
        self.size = size
        self.model = model
        if not hasattr(self.model, 'device'):
            logging.debug("Model doesn't have a device attribute.")
            self.model.device = offload_device
        elif self.model.device is None:
            self.model.device = offload_device

        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.weight_wrapper_patches = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        self.weight_inplace_update = weight_inplace_update
        self.force_cast_weights = False
        self.patches_uuid = uuid.uuid4()
        self.parent = None
        self.pinned = set()

        self.attachments: dict[str] = {}
        self.additional_models: dict[str, list[ModelPatcher]] = {}
        self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
        self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()

        self.is_injected = False
        self.skip_injection = False
        self.injections: dict[str, list[PatcherInjection]] = {}

        self.hook_patches: dict[comfy.hooks._HookRef] = {}
        self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
        self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
        self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
        self.current_hooks: Optional[comfy.hooks.HookGroup] = None
        self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
        self.is_clip = False
        self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed

        if not hasattr(self.model, 'model_loaded_weight_memory'):
            self.model.model_loaded_weight_memory = 0

        if not hasattr(self.model, 'lowvram_patch_counter'):
            self.model.lowvram_patch_counter = 0

        if not hasattr(self.model, 'model_lowvram'):
            self.model.model_lowvram = False

        if not hasattr(self.model, 'current_weight_patches_uuid'):
            self.model.current_weight_patches_uuid = None

    def model_size(self):
        if self.size > 0:
            return self.size
        self.size = comfy.model_management.module_size(self.model)
        return self.size

    def get_ram_usage(self):
        return self.model_size()

    def loaded_size(self):
        return self.model.model_loaded_weight_memory

    def lowvram_patch_counter(self):
        return self.model.lowvram_patch_counter

    def clone(self):
        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n.parent = self

        n.force_cast_weights = self.force_cast_weights

        # attachments
        n.attachments = {}
        for k in self.attachments:
            if hasattr(self.attachments[k], "on_model_patcher_clone"):
                n.attachments[k] = self.attachments[k].on_model_patcher_clone()
            else:
                n.attachments[k] = self.attachments[k]
        # additional models
        for k, c in self.additional_models.items():
            n.additional_models[k] = [x.clone() for x in c]
        # callbacks
        for k, c in self.callbacks.items():
            n.callbacks[k] = {}
            for k1, c1 in c.items():
                n.callbacks[k][k1] = c1.copy()
        # sample wrappers
        for k, w in self.wrappers.items():
            n.wrappers[k] = {}
            for k1, w1 in w.items():
                n.wrappers[k][k1] = w1.copy()
        # injection
        n.is_injected = self.is_injected
        n.skip_injection = self.skip_injection
        for k, i in self.injections.items():
            n.injections[k] = i.copy()
        # hooks
        n.hook_patches = create_hook_patches_clone(self.hook_patches)
        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
        for group in self.cached_hook_patches:
            n.cached_hook_patches[group] = {}
            for k in self.cached_hook_patches[group]:
                n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
        n.hook_backup = self.hook_backup
        n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
        n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
        n.is_clip = self.is_clip
        n.hook_mode = self.hook_mode

        for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
            callback(self, n)
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def clone_has_same_weights(self, clone: 'ModelPatcher'):
        if not self.is_clone(clone):
            return False

        if self.current_hooks != clone.current_hooks:
            return False
        if self.forced_hooks != clone.forced_hooks:
            return False
        if self.hook_patches.keys() != clone.hook_patches.keys():
            return False
        if self.attachments.keys() != clone.attachments.keys():
            return False
        if self.additional_models.keys() != clone.additional_models.keys():
            return False
        for key in self.callbacks:
            if len(self.callbacks[key]) != len(clone.callbacks[key]):
                return False
        for key in self.wrappers:
            if len(self.wrappers[key]) != len(clone.wrappers[key]):
                return False
        if self.injections.keys() != clone.injections.keys():
            return False

        if len(self.patches) == 0 and len(clone.patches) == 0:
            return True

        if self.patches_uuid == clone.patches_uuid:
            if len(self.patches) != len(clone.patches):
                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
            else:
                return True

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
        self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function

    def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def set_model_emb_patch(self, patch):
        self.set_model_patch(patch, "emb_patch")

    def set_model_forward_timestep_embed_patch(self, patch):
        self.set_model_patch(patch, "forward_timestep_embed_patch")

    def set_model_double_block_patch(self, patch):
        self.set_model_patch(patch, "double_block")

    def set_model_post_input_patch(self, patch):
        self.set_model_patch(patch, "post_input")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def set_model_compute_dtype(self, dtype):
        self.add_object_patch("manual_cast_dtype", dtype)
        if dtype is not None:
            self.force_cast_weights = True
        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this

    def add_weight_wrapper(self, name, function):
        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
        self.patches_uuid = uuid.uuid4()

    def get_model_object(self, name: str) -> torch.nn.Module:
        """Retrieves a nested attribute from an object using dot notation considering
        object patches.

        Args:
            name (str): The attribute path using dot notation (e.g. "model.layer.weight")

        Returns:
            The value of the requested attribute

        Example:
            patcher = ModelPatcher()
            weight = patcher.get_model_object("layer1.conv.weight")
        """
        if name in self.object_patches:
            return self.object_patches[name]
        else:
            if name in self.object_patches_backup:
                return self.object_patches_backup[name]
            else:
                return comfy.utils.get_attr(self.model, name)

    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_patches_models(self):
        to = self.model_options["transformer_options"]
        models = []
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "models"):
                        models += patch_list[i].models()
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "models"):
                        models += patch_list[k].models()
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "models"):
                models += wrap_func.models()

        return models

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        with self.use_ejected():
            p = set()
            model_sd = self.model.state_dict()
            for k in patches:
                offset = None
                function = None
                if isinstance(k, str):
                    key = k
                else:
                    offset = k[1]
                    key = k[0]
                    if len(k) > 2:
                        function = k[2]

                if key in model_sd:
                    p.add(k)
                    current_patches = self.patches.get(key, [])
                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                    self.patches[key] = current_patches

            self.patches_uuid = uuid.uuid4()
            return list(p)

    def get_key_patches(self, filter_prefix=None):
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            bk = self.backup.get(k, None)
            hbk = self.hook_backup.get(k, None)
            weight, set_func, convert_func = get_key_weight(self.model, k)
            if bk is not None:
                weight = bk.weight
            if hbk is not None:
                weight = hbk[0]
            if convert_func is None:
                convert_func = lambda a, **kwargs: a

            if k in self.patches:
                p[k] = [(weight, convert_func)] + self.patches[k]
            else:
                p[k] = [(weight, convert_func)]
        return p

    def model_state_dict(self, filter_prefix=None):
        with self.use_ejected():
            sd = self.model.state_dict()
            keys = list(sd.keys())
            if filter_prefix is not None:
                for k in keys:
                    if not k.startswith(filter_prefix):
                        sd.pop(k)
            return sd

    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        if key not in self.patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        inplace_update = self.weight_inplace_update or inplace_update

        if key not in self.backup:
            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
        if set_func is None:
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            if inplace_update:
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                comfy.utils.set_attr_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

    def pin_weight_to_device(self, key):
        weight, set_func, convert_func = get_key_weight(self.model, key)
        if comfy.model_management.pin_memory(weight):
            self.pinned.add(key)

    def unpin_weight(self, key):
        if key in self.pinned:
            weight, set_func, convert_func = get_key_weight(self.model, key)
            comfy.model_management.unpin_memory(weight)
            self.pinned.remove(key)

    def unpin_all_weights(self):
        for key in list(self.pinned):
            self.unpin_weight(key)

    def _load_list(self):
        loading = []
        for n, m in self.model.named_modules():
            params = []
            skip = False
            for name, param in m.named_parameters(recurse=False):
                params.append(name)
            for name, param in m.named_parameters(recurse=True):
                if name not in params:
                    skip = True # skip random weights in non leaf modules
                    break
            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                loading.append((comfy.model_management.module_size(m), n, m, params))
        return loading

    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        with self.use_ejected():
            self.unpatch_hooks()
            mem_counter = 0
            patch_counter = 0
            lowvram_counter = 0
            lowvram_mem_counter = 0
            loading = self._load_list()

            load_completely = []
            offloaded = []
            loading.sort(reverse=True)
            for x in loading:
                n = x[1]
                m = x[2]
                params = x[3]
                module_mem = x[0]

                lowvram_weight = False

                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)

                if not full_load and hasattr(m, "comfy_cast_weights"):
                    if mem_counter + module_mem >= lowvram_model_memory:
                        lowvram_weight = True
                        lowvram_counter += 1
                        lowvram_mem_counter += module_mem
                        if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                            continue

                cast_weight = self.force_cast_weights
                if lowvram_weight:
                    if hasattr(m, "comfy_cast_weights"):
                        m.weight_function = []
                        m.bias_function = []

                        if weight_key in self.patches:
                            if force_patch_weights:
                                self.patch_weight_to_device(weight_key)
                            else:
                                _, set_func, convert_func = get_key_weight(self.model, weight_key)
                                m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
                                patch_counter += 1
                        if bias_key in self.patches:
                            if force_patch_weights:
                                self.patch_weight_to_device(bias_key)
                            else:
                                _, set_func, convert_func = get_key_weight(self.model, bias_key)
                                m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
                                patch_counter += 1

                        cast_weight = True
                    offloaded.append((module_mem, n, m, params))
                else:
                    if hasattr(m, "comfy_cast_weights"):
                        wipe_lowvram_weight(m)

                    if full_load or mem_counter + module_mem < lowvram_model_memory:
                        mem_counter += module_mem
                        load_completely.append((module_mem, n, m, params))

                if cast_weight and hasattr(m, "comfy_cast_weights"):
                    m.prev_comfy_cast_weights = m.comfy_cast_weights
                    m.comfy_cast_weights = True

                if weight_key in self.weight_wrapper_patches:
                    m.weight_function.extend(self.weight_wrapper_patches[weight_key])

                if bias_key in self.weight_wrapper_patches:
                    m.bias_function.extend(self.weight_wrapper_patches[bias_key])

                mem_counter += move_weight_functions(m, device_to)

            load_completely.sort(reverse=True)
            for x in load_completely:
                n = x[1]
                m = x[2]
                params = x[3]
                if hasattr(m, "comfy_patched_weights"):
                    if m.comfy_patched_weights == True:
                        continue

                for param in params:
                    key = "{}.{}".format(n, param)
                    self.unpin_weight(key)
                    self.patch_weight_to_device(key, device_to=device_to)

                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                m.comfy_patched_weights = True

            for x in load_completely:
                x[2].to(device_to)

            for x in offloaded:
                n = x[1]
                params = x[3]
                for param in params:
                    self.pin_weight_to_device("{}.{}".format(n, param))

            if lowvram_counter > 0:
                logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
                self.model.model_lowvram = True
            else:
                logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                self.model.model_lowvram = False
                if full_load:
                    self.model.to(device_to)
                    mem_counter = self.model_size()

            self.model.lowvram_patch_counter += patch_counter
            self.model.device = device_to
            self.model.model_loaded_weight_memory = mem_counter
            self.model.current_weight_patches_uuid = self.patches_uuid

        for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
            callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

        self.apply_hooks(self.forced_hooks, force_apply=True)

    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
        with self.use_ejected():
            for k in self.object_patches:
                old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
                if k not in self.object_patches_backup:
                    self.object_patches_backup[k] = old

            if lowvram_model_memory == 0:
                full_load = True
            else:
                full_load = False

            if load_weights:
                self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
        self.inject_model()
        return self.model

    def unpatch_model(self, device_to=None, unpatch_weights=True):
        self.eject_model()
        if unpatch_weights:
            self.unpatch_hooks()
            self.unpin_all_weights()
            if self.model.model_lowvram:
                for m in self.model.modules():
                    move_weight_functions(m, device_to)
                    wipe_lowvram_weight(m)

                self.model.model_lowvram = False
                self.model.lowvram_patch_counter = 0

            keys = list(self.backup.keys())

            for k in keys:
                bk = self.backup[k]
                if bk.inplace_update:
                    comfy.utils.copy_to_param(self.model, k, bk.weight)
                else:
                    comfy.utils.set_attr_param(self.model, k, bk.weight)

            self.model.current_weight_patches_uuid = None
            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.model.device = device_to
            self.model.model_loaded_weight_memory = 0

            for m in self.model.modules():
                if hasattr(m, "comfy_patched_weights"):
                    del m.comfy_patched_weights

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup.clear()

    def partially_unload(self, device_to, memory_to_free=0):
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
            unload_list.sort()
            for unload in unload_list:
                if memory_to_free < memory_freed:
                    break
                module_mem = unload[0]
                n = unload[1]
                m = unload[2]
                params = unload[3]

                lowvram_possible = hasattr(m, "comfy_cast_weights")
                if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                    move_weight = True
                    for param in params:
                        key = "{}.{}".format(n, param)
                        bk = self.backup.get(key, None)
                        if bk is not None:
                            if not lowvram_possible:
                                move_weight = False
                                break

                            if not hooks_unpatched:
                                self.unpatch_hooks()
                                hooks_unpatched = True

                            if bk.inplace_update:
                                comfy.utils.copy_to_param(self.model, key, bk.weight)
                            else:
                                comfy.utils.set_attr_param(self.model, key, bk.weight)
                            self.backup.pop(key)

                    weight_key = "{}.weight".format(n)
                    bias_key = "{}.bias".format(n)
                    if move_weight:
                        cast_weight = self.force_cast_weights
                        m.to(device_to)
                        module_mem += move_weight_functions(m, device_to)
                        if lowvram_possible:
                            if weight_key in self.patches:
                                _, set_func, convert_func = get_key_weight(self.model, weight_key)
                                m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
                                patch_counter += 1
                            if bias_key in self.patches:
                                _, set_func, convert_func = get_key_weight(self.model, bias_key)
                                m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
                                patch_counter += 1
                            cast_weight = True

                        if cast_weight:
                            m.prev_comfy_cast_weights = m.comfy_cast_weights
                            m.comfy_cast_weights = True
                        m.comfy_patched_weights = False
                        memory_freed += module_mem
                        logging.debug("freed {}".format(n))

                        for param in params:
                            self.pin_weight_to_device("{}.{}".format(n, param))

            self.model.model_lowvram = True
            self.model.lowvram_patch_counter += patch_counter
            self.model.model_loaded_weight_memory -= memory_freed
            return memory_freed

    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
        with self.use_ejected(skip_and_inject_on_exit_only=True):
            unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
            # TODO: force_patch_weights should not unload + reload full model
            used = self.model.model_loaded_weight_memory
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
            if unpatch_weights:
                extra_memory += (used - self.model.model_loaded_weight_memory)

            self.patch_model(load_weights=False)
            full_load = False
            if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
                self.apply_hooks(self.forced_hooks, force_apply=True)
                return 0
            if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
                full_load = True
            current_used = self.model.model_loaded_weight_memory
            try:
                self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
            except Exception as e:
                self.detach()
                raise e

        return self.model.model_loaded_weight_memory - current_used

    def detach(self, unpatch_all=True):
        self.eject_model()
        self.model_patches_to(self.offload_device)
        if unpatch_all:
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
        for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
            callback(self, unpatch_all)
        return self.model

    def current_loaded_device(self):
        return self.model.device

    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
        logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
        return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

    def cleanup(self):
        self.clean_hooks()
        if hasattr(self.model, "current_patcher"):
            self.model.current_patcher = None
        for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
            callback(self)

    def add_callback(self, call_type: str, callback: Callable):
        self.add_callback_with_key(call_type, None, callback)

    def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
        c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
        c.append(callback)

    def remove_callbacks_with_key(self, call_type: str, key: str):
        c = self.callbacks.get(call_type, {})
        if key in c:
            c.pop(key)

    def get_callbacks(self, call_type: str, key: str):
        return self.callbacks.get(call_type, {}).get(key, [])

    def get_all_callbacks(self, call_type: str):
        c_list = []
        for c in self.callbacks.get(call_type, {}).values():
            c_list.extend(c)
        return c_list

    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
        w.append(wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str):
        w = self.wrappers.get(wrapper_type, {})
        if key in w:
            w.pop(key)

    def get_wrappers(self, wrapper_type: str, key: str):
        return self.wrappers.get(wrapper_type, {}).get(key, [])

    def get_all_wrappers(self, wrapper_type: str):
        w_list = []
        for w in self.wrappers.get(wrapper_type, {}).values():
            w_list.extend(w)
        return w_list

    def set_attachments(self, key: str, attachment):
        self.attachments[key] = attachment

    def remove_attachments(self, key: str):
        if key in self.attachments:
            self.attachments.pop(key)

    def get_attachment(self, key: str):
        return self.attachments.get(key, None)

    def set_injections(self, key: str, injections: list[PatcherInjection]):
        self.injections[key] = injections

    def remove_injections(self, key: str):
        if key in self.injections:
            self.injections.pop(key)

    def get_injections(self, key: str):
        return self.injections.get(key, None)

    def set_additional_models(self, key: str, models: list['ModelPatcher']):
        self.additional_models[key] = models

    def remove_additional_models(self, key: str):
        if key in self.additional_models:
            self.additional_models.pop(key)

    def get_additional_models_with_key(self, key: str):
        return self.additional_models.get(key, [])

    def get_additional_models(self):
        all_models = []
        for models in self.additional_models.values():
            all_models.extend(models)
        return all_models

    def get_nested_additional_models(self):
        def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
            '''Make sure circular references do not cause infinite recursion.'''
            next_models = []
            for model in prev_models:
                candidates = model.get_additional_models()
                for c in candidates:
                    if c not in cache_set:
                        next_models.append(c)
                        cache_set.add(c)
            if len(next_models) == 0:
                return prev_models
            return prev_models + _evaluate_sub_additional_models(next_models, cache_set)

        all_models = self.get_additional_models()
        models_set = set(all_models)
        real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
        return real_all_models

    def use_ejected(self, skip_and_inject_on_exit_only=False):
        return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)

    def inject_model(self):
        if self.is_injected or self.skip_injection:
            return
        for injections in self.injections.values():
            for inj in injections:
                inj.inject(self)
                self.is_injected = True
        if self.is_injected:
            for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
                callback(self)

    def eject_model(self):
        if not self.is_injected:
            return
        for injections in self.injections.values():
            for inj in injections:
                inj.eject(self)
        self.is_injected = False
        for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
            callback(self)

    def pre_run(self):
        if hasattr(self.model, "current_patcher"):
            self.model.current_patcher = self
        for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
            callback(self)

    def prepare_state(self, timestep):
        for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
            callback(self, timestep)

    def restore_hook_patches(self):
        if self.hook_patches_backup is not None:
            self.hook_patches = self.hook_patches_backup
            self.hook_patches_backup = None

    def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
        self.hook_mode = hook_mode

    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
        curr_t = t[0]
        reset_current_hooks = False
        transformer_options = model_options.get("transformer_options", {})
        for hook in hook_group.hooks:
            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
            # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
            # this will cause the weights to be recalculated when sampling
            if changed:
                # reset current_hooks if contains hook that changed
                if self.current_hooks is not None:
                    for current_hook in self.current_hooks.hooks:
                        if current_hook == hook:
                            reset_current_hooks = True
                            break
                for cached_group in list(self.cached_hook_patches.keys()):
                    if cached_group.contains(hook):
                        self.cached_hook_patches.pop(cached_group)
        if reset_current_hooks:
            self.patch_hooks(None)

    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
                                  registered: comfy.hooks.HookGroup = None):
        self.restore_hook_patches()
        if registered is None:
            registered = comfy.hooks.HookGroup()
        # handle WeightHooks
        weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
        for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
            if hook.hook_ref not in self.hook_patches:
                weight_hooks_to_register.append(hook)
            else:
                registered.add(hook)
        if len(weight_hooks_to_register) > 0:
            # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
            self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
            for hook in weight_hooks_to_register:
                hook.add_hook_patches(self, model_options, target_dict, registered)
        for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
            callback(self, hooks, target_dict, model_options, registered)
        return registered

    def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
        with self.use_ejected():
            # NOTE: this mirrors behavior of add_patches func
            current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
            p = set()
            model_sd = self.model.state_dict()
            for k in patches:
                offset = None
                function = None
                if isinstance(k, str):
                    key = k
                else:
                    offset = k[1]
                    key = k[0]
                    if len(k) > 2:
                        function = k[2]

                if key in model_sd:
                    p.add(k)
                    current_patches: list[tuple] = current_hook_patches.get(key, [])
                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                    current_hook_patches[key] = current_patches
            self.hook_patches[hook.hook_ref] = current_hook_patches
            # since should care about these patches too to determine if same model, reroll patches_uuid
            self.patches_uuid = uuid.uuid4()
            return list(p)

    def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
        # combined_patches will contain weights of all relevant hooks, per key
        combined_patches = {}
        if hooks is not None:
            for hook in hooks.hooks:
                hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
                for key in hook_patches.keys():
                    current_patches: list[tuple] = combined_patches.get(key, [])
                    if math.isclose(hook.strength, 1.0):
                        current_patches.extend(hook_patches[key])
                    else:
                        # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
                        for patch in hook_patches[key]:
                            new_patch = list(patch)
                            new_patch[0] *= hook.strength
                            current_patches.append(tuple(new_patch))
                    combined_patches[key] = current_patches
        return combined_patches

    def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
        # TODO: return transformer_options dict with any additions from hooks
        if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
        self.patch_hooks(hooks=hooks)
        for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
            callback(self, hooks)
        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)

    def patch_hooks(self, hooks: comfy.hooks.HookGroup):
        with self.use_ejected():
            if hooks is not None:
                model_sd_keys = list(self.model_state_dict().keys())
                memory_counter = None
                if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                    # TODO: minimum_counter should have a minimum that conforms to loaded model requirements
                    memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
                                                   minimum=comfy.model_management.minimum_inference_memory()*2)
                # if have cached weights for hooks, use it
                cached_weights = self.cached_hook_patches.get(hooks, None)
                if cached_weights is not None:
                    model_sd_keys_set = set(model_sd_keys)
                    for key in cached_weights:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
                        model_sd_keys_set.remove(key)
                    self.unpatch_hooks(model_sd_keys_set)
                else:
                    self.unpatch_hooks()
                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                    original_weights = None
                    if len(relevant_patches) > 0:
                        original_weights = self.get_key_patches()
                    for key in relevant_patches:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                         memory_counter=memory_counter)
            else:
                self.unpatch_hooks()
            self.current_hooks = hooks

    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
        if key not in self.hook_backup:
            weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                used = memory_counter.use(weight)
                if used:
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))

    def clear_cached_hook_weights(self):
        self.cached_hook_patches.clear()
        self.patch_hooks(None)

    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
        if key not in combined_patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        weight: torch.Tensor
        if key not in self.hook_backup:
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                used = memory_counter.use(weight)
                if used:
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        # TODO: properly handle LowVramPatch, if it ends up an issue
        temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(combined_patches[key],
                                                 temp_weight,
                                                 key, original_weights=original_weights)
        del original_weights[key]
        if set_func is None:
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
        if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
            # TODO: disable caching if not enough system RAM to do so
            target_device = self.offload_device
            used = memory_counter.use(weight)
            if used:
                target_device = weight.device
            self.cached_hook_patches.setdefault(hooks, {})
            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
        del temp_weight
        del out_weight
        del weight

    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
            keys = list(self.hook_backup.keys())
            if whitelist_keys_set:
                for k in keys:
                    if k in whitelist_keys_set:
                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
                        self.hook_backup.pop(k)
            else:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

                self.hook_backup.clear()
            self.current_hooks = None

    def clean_hooks(self):
        self.unpatch_hooks()
        self.clear_cached_hook_weights()

    def __del__(self):
        self.unpin_all_weights()
        self.detach(unpatch_all=False)