Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-08 21:44:33 +08:00)
* mm: default to 0 for NUM_STREAMS
  Don't count the compute stream as an offload stream. This makes async offload accounting easier.
* mm: remove the 128MB minimum
  This came from a previous offloading system requirement. Removing it makes the behaviour of the loader and the partial unloader consistent.
* mp: order the module list by offload expense
  Calculate an approximate temporary VRAM cost for offloading a weight and primarily order the module load list by that. In the simple case this is just the module weight, but a weight with a LoRA attached consumes considerably more VRAM for the on-the-fly LoRA application. This slightly prioritizes LoRA weights, but is really there for proper VRAM offload accounting.
* mp: account for the VRAM cost of weight offloading
  When checking the VRAM headroom, assume that the weight needs to be offloaded, and only load it if there is space for both the load and the offload cost multiplied by the number of streams. As the weights are ordered from largest to smallest by offload cost, this is guaranteed to fit in VRAM (tm), since every weight that follows will be smaller.
  Make the partial unloader aware of this system as well by saving the offload VRAM budget to the model state and accounting accordingly. It's possible that a partial unload increases the size of the largest offloaded weights, and thus needs to unload a little more than asked to accommodate the bigger temp buffers.
  Honor the existing code's 128MB floor on model weight loading by having the patcher apply it separately, without regard to offloading. Otherwise, when MM specifies its 128MB minimum, MP would see the biggest weights and budget that 128MB entirely to the offload buffer while loading nothing, which isn't the intent of these minimums. The same clamp applies in case of a partial offload of the currently loading model.
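To make the fourth bullet concrete, here is a minimal sketch of the budget rule, written against the quantities used in ModelPatcher.load() further down (module_mem, module_offload_mem, offload_buffer, mem_counter, lowvram_model_memory, NUM_STREAMS). It is an illustration only, not the loader itself, and the numbers in the example are made up.

# Illustrative sketch only (not part of model_patcher.py): the offload-aware
# budget check described above, using the same quantities as ModelPatcher.load().
NUM_STREAMS = 0  # per the first bullet: the compute stream is not an offload stream

def module_fits(mem_counter, module_mem, module_offload_mem, offload_buffer, lowvram_model_memory):
    # Reserve room for the largest offload temp buffer seen so far, scaled by the
    # number of copy streams plus one.
    potential_offload = max(offload_buffer, module_offload_mem * (NUM_STREAMS + 1))
    # Load the module fully only if the weight itself AND the reserved offload
    # buffer both fit under the lowvram budget.
    fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
    return fits, potential_offload

# Example: a 64 MiB weight whose LoRA patch needs ~384 MiB of fp32 temporaries.
MiB = 1024 * 1024
fits, reserve = module_fits(mem_counter=0, module_mem=64 * MiB,
                            module_offload_mem=384 * MiB, offload_buffer=0,
                            lowvram_model_memory=512 * MiB)
print(fits, reserve // MiB)  # True 384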
1354 lines
58 KiB
Python
"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
from __future__ import annotations

import collections
import copy
import inspect
import logging
import math
import uuid
from typing import Callable, Optional

import torch

import comfy.float
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF

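# Note (added for clarity, not part of the original source): string_to_seed()
# computes a CRC-32 of the weight key, giving a deterministic per-key seed for
# the stochastic rounding calls used when applying patches below.
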
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
|
to = model_options["transformer_options"].copy()
|
|
|
|
if "patches_replace" not in to:
|
|
to["patches_replace"] = {}
|
|
else:
|
|
to["patches_replace"] = to["patches_replace"].copy()
|
|
|
|
if name not in to["patches_replace"]:
|
|
to["patches_replace"][name] = {}
|
|
else:
|
|
to["patches_replace"][name] = to["patches_replace"][name].copy()
|
|
|
|
if transformer_index is not None:
|
|
block = (block_name, number, transformer_index)
|
|
else:
|
|
block = (block_name, number)
|
|
to["patches_replace"][name][block] = patch
|
|
model_options["transformer_options"] = to
|
|
return model_options
|
|
|
|
def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
|
|
model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
|
if disable_cfg1_optimization:
|
|
model_options["disable_cfg1_optimization"] = True
|
|
return model_options
|
|
|
|
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
|
|
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
|
|
if disable_cfg1_optimization:
|
|
model_options["disable_cfg1_optimization"] = True
|
|
return model_options
|
|
|
|
def create_model_options_clone(orig_model_options: dict):
|
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
|
|
|
def create_hook_patches_clone(orig_hook_patches):
|
|
new_hook_patches = {}
|
|
for hook_ref in orig_hook_patches:
|
|
new_hook_patches[hook_ref] = {}
|
|
for k in orig_hook_patches[hook_ref]:
|
|
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
|
return new_hook_patches
|
|
|
|
def wipe_lowvram_weight(m):
|
|
if hasattr(m, "prev_comfy_cast_weights"):
|
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
|
del m.prev_comfy_cast_weights
|
|
|
|
if hasattr(m, "weight_function"):
|
|
m.weight_function = []
|
|
|
|
if hasattr(m, "bias_function"):
|
|
m.bias_function = []
|
|
|
|
def move_weight_functions(m, device):
|
|
if device is None:
|
|
return 0
|
|
|
|
memory = 0
|
|
if hasattr(m, "weight_function"):
|
|
for f in m.weight_function:
|
|
if hasattr(f, "move_to"):
|
|
memory += f.move_to(device=device)
|
|
|
|
if hasattr(m, "bias_function"):
|
|
for f in m.bias_function:
|
|
if hasattr(f, "move_to"):
|
|
memory += f.move_to(device=device)
|
|
return memory
|
|
|
|
class LowVramPatch:
    def __init__(self, key, patches, convert_func=None, set_func=None):
        self.key = key
        self.patches = patches
        self.convert_func = convert_func
        self.set_func = set_func

    def __call__(self, weight):
        intermediate_dtype = weight.dtype
        if self.convert_func is not None:
            weight = self.convert_func(weight, inplace=False)

        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
            intermediate_dtype = torch.float32
            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
            if self.set_func is None:
                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
            else:
                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)

        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
        if self.set_func is not None:
            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
        else:
            return out

#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3

def low_vram_patch_estimate_vram(model, key):
    weight, set_func, convert_func = get_key_weight(model, key)
    if weight is None:
        return 0
    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

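# Worked example (added for illustration, not part of the original source): a
# 64 MiB fp16 weight has ~33.5M elements, so the estimate above is
# 33,554,432 * 4 bytes * 3 = 384 MiB of temporary VRAM to apply its patches
# on the fly while offloaded.
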
def get_key_weight(model, key):
|
|
set_func = None
|
|
convert_func = None
|
|
op_keys = key.rsplit('.', 1)
|
|
if len(op_keys) < 2:
|
|
weight = comfy.utils.get_attr(model, key)
|
|
else:
|
|
op = comfy.utils.get_attr(model, op_keys[0])
|
|
try:
|
|
set_func = getattr(op, "set_{}".format(op_keys[1]))
|
|
except AttributeError:
|
|
pass
|
|
|
|
try:
|
|
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
|
|
except AttributeError:
|
|
pass
|
|
|
|
weight = getattr(op, op_keys[1])
|
|
if convert_func is not None:
|
|
weight = comfy.utils.get_attr(model, key)
|
|
|
|
return weight, set_func, convert_func
|
|
|
|
class AutoPatcherEjector:
|
|
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
|
self.model = model
|
|
self.was_injected = False
|
|
self.prev_skip_injection = False
|
|
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
|
|
|
def __enter__(self):
|
|
self.was_injected = False
|
|
self.prev_skip_injection = self.model.skip_injection
|
|
if self.skip_and_inject_on_exit_only:
|
|
self.model.skip_injection = True
|
|
if self.model.is_injected:
|
|
self.model.eject_model()
|
|
self.was_injected = True
|
|
|
|
def __exit__(self, *args):
|
|
if self.skip_and_inject_on_exit_only:
|
|
self.model.skip_injection = self.prev_skip_injection
|
|
self.model.inject_model()
|
|
if self.was_injected and not self.model.skip_injection:
|
|
self.model.inject_model()
|
|
self.model.skip_injection = self.prev_skip_injection
|
|
|
|
class MemoryCounter:
|
|
def __init__(self, initial: int, minimum=0):
|
|
self.value = initial
|
|
self.minimum = minimum
|
|
# TODO: add a safe limit besides 0
|
|
|
|
def use(self, weight: torch.Tensor):
|
|
weight_size = weight.nelement() * weight.element_size()
|
|
if self.is_useable(weight_size):
|
|
self.decrement(weight_size)
|
|
return True
|
|
return False
|
|
|
|
def is_useable(self, used: int):
|
|
return self.value - used > self.minimum
|
|
|
|
def decrement(self, used: int):
|
|
self.value -= used
|
|
|
|
class ModelPatcher:
|
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
|
self.size = size
|
|
self.model = model
|
|
if not hasattr(self.model, 'device'):
|
|
logging.debug("Model doesn't have a device attribute.")
|
|
self.model.device = offload_device
|
|
elif self.model.device is None:
|
|
self.model.device = offload_device
|
|
|
|
self.patches = {}
|
|
self.backup = {}
|
|
self.object_patches = {}
|
|
self.object_patches_backup = {}
|
|
self.weight_wrapper_patches = {}
|
|
self.model_options = {"transformer_options":{}}
|
|
self.load_device = load_device
|
|
self.offload_device = offload_device
|
|
self.weight_inplace_update = weight_inplace_update
|
|
self.force_cast_weights = False
|
|
self.patches_uuid = uuid.uuid4()
|
|
self.parent = None
|
|
self.pinned = set()
|
|
|
|
self.attachments: dict[str] = {}
|
|
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
|
self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
|
|
self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()
|
|
|
|
self.is_injected = False
|
|
self.skip_injection = False
|
|
self.injections: dict[str, list[PatcherInjection]] = {}
|
|
|
|
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
|
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
|
|
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
|
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
|
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
|
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
|
|
self.is_clip = False
|
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
|
|
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
|
self.model.model_loaded_weight_memory = 0
|
|
|
|
if not hasattr(self.model, 'lowvram_patch_counter'):
|
|
self.model.lowvram_patch_counter = 0
|
|
|
|
if not hasattr(self.model, 'model_lowvram'):
|
|
self.model.model_lowvram = False
|
|
|
|
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
|
self.model.current_weight_patches_uuid = None
|
|
|
|
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
|
self.model.model_offload_buffer_memory = 0
|
|
|
|
def model_size(self):
|
|
if self.size > 0:
|
|
return self.size
|
|
self.size = comfy.model_management.module_size(self.model)
|
|
return self.size
|
|
|
|
def get_ram_usage(self):
|
|
return self.model_size()
|
|
|
|
def loaded_size(self):
|
|
return self.model.model_loaded_weight_memory
|
|
|
|
def lowvram_patch_counter(self):
|
|
return self.model.lowvram_patch_counter
|
|
|
|
def clone(self):
|
|
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
|
n.patches = {}
|
|
for k in self.patches:
|
|
n.patches[k] = self.patches[k][:]
|
|
n.patches_uuid = self.patches_uuid
|
|
|
|
n.object_patches = self.object_patches.copy()
|
|
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
|
n.model_options = copy.deepcopy(self.model_options)
|
|
n.backup = self.backup
|
|
n.object_patches_backup = self.object_patches_backup
|
|
n.parent = self
|
|
n.pinned = self.pinned
|
|
|
|
n.force_cast_weights = self.force_cast_weights
|
|
|
|
# attachments
|
|
n.attachments = {}
|
|
for k in self.attachments:
|
|
if hasattr(self.attachments[k], "on_model_patcher_clone"):
|
|
n.attachments[k] = self.attachments[k].on_model_patcher_clone()
|
|
else:
|
|
n.attachments[k] = self.attachments[k]
|
|
# additional models
|
|
for k, c in self.additional_models.items():
|
|
n.additional_models[k] = [x.clone() for x in c]
|
|
# callbacks
|
|
for k, c in self.callbacks.items():
|
|
n.callbacks[k] = {}
|
|
for k1, c1 in c.items():
|
|
n.callbacks[k][k1] = c1.copy()
|
|
# sample wrappers
|
|
for k, w in self.wrappers.items():
|
|
n.wrappers[k] = {}
|
|
for k1, w1 in w.items():
|
|
n.wrappers[k][k1] = w1.copy()
|
|
# injection
|
|
n.is_injected = self.is_injected
|
|
n.skip_injection = self.skip_injection
|
|
for k, i in self.injections.items():
|
|
n.injections[k] = i.copy()
|
|
# hooks
|
|
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
|
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
|
|
for group in self.cached_hook_patches:
|
|
n.cached_hook_patches[group] = {}
|
|
for k in self.cached_hook_patches[group]:
|
|
n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
|
|
n.hook_backup = self.hook_backup
|
|
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
|
|
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
|
|
n.is_clip = self.is_clip
|
|
n.hook_mode = self.hook_mode
|
|
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
|
callback(self, n)
|
|
return n
|
|
|
|
def is_clone(self, other):
|
|
if hasattr(other, 'model') and self.model is other.model:
|
|
return True
|
|
return False
|
|
|
|
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
|
if not self.is_clone(clone):
|
|
return False
|
|
|
|
if self.current_hooks != clone.current_hooks:
|
|
return False
|
|
if self.forced_hooks != clone.forced_hooks:
|
|
return False
|
|
if self.hook_patches.keys() != clone.hook_patches.keys():
|
|
return False
|
|
if self.attachments.keys() != clone.attachments.keys():
|
|
return False
|
|
if self.additional_models.keys() != clone.additional_models.keys():
|
|
return False
|
|
for key in self.callbacks:
|
|
if len(self.callbacks[key]) != len(clone.callbacks[key]):
|
|
return False
|
|
for key in self.wrappers:
|
|
if len(self.wrappers[key]) != len(clone.wrappers[key]):
|
|
return False
|
|
if self.injections.keys() != clone.injections.keys():
|
|
return False
|
|
|
|
if len(self.patches) == 0 and len(clone.patches) == 0:
|
|
return True
|
|
|
|
if self.patches_uuid == clone.patches_uuid:
|
|
if len(self.patches) != len(clone.patches):
|
|
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
|
|
else:
|
|
return True
|
|
|
|
def memory_required(self, input_shape):
|
|
return self.model.memory_required(input_shape=input_shape)
|
|
|
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
|
else:
|
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
|
if disable_cfg1_optimization:
|
|
self.model_options["disable_cfg1_optimization"] = True
|
|
|
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
|
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
|
|
|
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
|
|
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
|
|
|
|
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
|
|
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
|
|
|
|
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
|
|
|
def set_model_denoise_mask_function(self, denoise_mask_function):
|
|
self.model_options["denoise_mask_function"] = denoise_mask_function
|
|
|
|
def set_model_patch(self, patch, name):
|
|
to = self.model_options["transformer_options"]
|
|
if "patches" not in to:
|
|
to["patches"] = {}
|
|
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
|
|
|
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
|
self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
|
|
|
|
def set_model_attn1_patch(self, patch):
|
|
self.set_model_patch(patch, "attn1_patch")
|
|
|
|
def set_model_attn2_patch(self, patch):
|
|
self.set_model_patch(patch, "attn2_patch")
|
|
|
|
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
|
|
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
|
|
|
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
|
|
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
|
|
|
def set_model_attn1_output_patch(self, patch):
|
|
self.set_model_patch(patch, "attn1_output_patch")
|
|
|
|
def set_model_attn2_output_patch(self, patch):
|
|
self.set_model_patch(patch, "attn2_output_patch")
|
|
|
|
def set_model_input_block_patch(self, patch):
|
|
self.set_model_patch(patch, "input_block_patch")
|
|
|
|
def set_model_input_block_patch_after_skip(self, patch):
|
|
self.set_model_patch(patch, "input_block_patch_after_skip")
|
|
|
|
def set_model_output_block_patch(self, patch):
|
|
self.set_model_patch(patch, "output_block_patch")
|
|
|
|
def set_model_emb_patch(self, patch):
|
|
self.set_model_patch(patch, "emb_patch")
|
|
|
|
def set_model_forward_timestep_embed_patch(self, patch):
|
|
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
|
|
|
def set_model_double_block_patch(self, patch):
|
|
self.set_model_patch(patch, "double_block")
|
|
|
|
def set_model_post_input_patch(self, patch):
|
|
self.set_model_patch(patch, "post_input")
|
|
|
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
|
rope_options["scale_x"] = scale_x
|
|
rope_options["scale_y"] = scale_y
|
|
rope_options["scale_t"] = scale_t
|
|
|
|
rope_options["shift_x"] = shift_x
|
|
rope_options["shift_y"] = shift_y
|
|
rope_options["shift_t"] = shift_t
|
|
|
|
self.model_options["transformer_options"]["rope_options"] = rope_options
|
|
|
|
|
|
def add_object_patch(self, name, obj):
|
|
self.object_patches[name] = obj
|
|
|
|
def set_model_compute_dtype(self, dtype):
|
|
self.add_object_patch("manual_cast_dtype", dtype)
|
|
if dtype is not None:
|
|
self.force_cast_weights = True
|
|
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
|
|
|
def add_weight_wrapper(self, name, function):
|
|
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
|
self.patches_uuid = uuid.uuid4()
|
|
|
|
def get_model_object(self, name: str) -> torch.nn.Module:
|
|
"""Retrieves a nested attribute from an object using dot notation considering
|
|
object patches.
|
|
|
|
Args:
|
|
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
|
|
|
Returns:
|
|
The value of the requested attribute
|
|
|
|
Example:
|
|
patcher = ModelPatcher()
|
|
weight = patcher.get_model_object("layer1.conv.weight")
|
|
"""
|
|
if name in self.object_patches:
|
|
return self.object_patches[name]
|
|
else:
|
|
if name in self.object_patches_backup:
|
|
return self.object_patches_backup[name]
|
|
else:
|
|
return comfy.utils.get_attr(self.model, name)
|
|
|
|
def model_patches_to(self, device):
|
|
to = self.model_options["transformer_options"]
|
|
if "patches" in to:
|
|
patches = to["patches"]
|
|
for name in patches:
|
|
patch_list = patches[name]
|
|
for i in range(len(patch_list)):
|
|
if hasattr(patch_list[i], "to"):
|
|
patch_list[i] = patch_list[i].to(device)
|
|
if "patches_replace" in to:
|
|
patches = to["patches_replace"]
|
|
for name in patches:
|
|
patch_list = patches[name]
|
|
for k in patch_list:
|
|
if hasattr(patch_list[k], "to"):
|
|
patch_list[k] = patch_list[k].to(device)
|
|
if "model_function_wrapper" in self.model_options:
|
|
wrap_func = self.model_options["model_function_wrapper"]
|
|
if hasattr(wrap_func, "to"):
|
|
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
|
|
|
def model_patches_models(self):
|
|
to = self.model_options["transformer_options"]
|
|
models = []
|
|
if "patches" in to:
|
|
patches = to["patches"]
|
|
for name in patches:
|
|
patch_list = patches[name]
|
|
for i in range(len(patch_list)):
|
|
if hasattr(patch_list[i], "models"):
|
|
models += patch_list[i].models()
|
|
if "patches_replace" in to:
|
|
patches = to["patches_replace"]
|
|
for name in patches:
|
|
patch_list = patches[name]
|
|
for k in patch_list:
|
|
if hasattr(patch_list[k], "models"):
|
|
models += patch_list[k].models()
|
|
if "model_function_wrapper" in self.model_options:
|
|
wrap_func = self.model_options["model_function_wrapper"]
|
|
if hasattr(wrap_func, "models"):
|
|
models += wrap_func.models()
|
|
|
|
return models
|
|
|
|
def model_dtype(self):
|
|
if hasattr(self.model, "get_dtype"):
|
|
return self.model.get_dtype()
|
|
|
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
|
with self.use_ejected():
|
|
p = set()
|
|
model_sd = self.model.state_dict()
|
|
for k in patches:
|
|
offset = None
|
|
function = None
|
|
if isinstance(k, str):
|
|
key = k
|
|
else:
|
|
offset = k[1]
|
|
key = k[0]
|
|
if len(k) > 2:
|
|
function = k[2]
|
|
|
|
if key in model_sd:
|
|
p.add(k)
|
|
current_patches = self.patches.get(key, [])
|
|
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
|
|
self.patches[key] = current_patches
|
|
|
|
self.patches_uuid = uuid.uuid4()
|
|
return list(p)
|
|
|
|
def get_key_patches(self, filter_prefix=None):
|
|
model_sd = self.model_state_dict()
|
|
p = {}
|
|
for k in model_sd:
|
|
if filter_prefix is not None:
|
|
if not k.startswith(filter_prefix):
|
|
continue
|
|
bk = self.backup.get(k, None)
|
|
hbk = self.hook_backup.get(k, None)
|
|
weight, set_func, convert_func = get_key_weight(self.model, k)
|
|
if bk is not None:
|
|
weight = bk.weight
|
|
if hbk is not None:
|
|
weight = hbk[0]
|
|
if convert_func is None:
|
|
convert_func = lambda a, **kwargs: a
|
|
|
|
if k in self.patches:
|
|
p[k] = [(weight, convert_func)] + self.patches[k]
|
|
else:
|
|
p[k] = [(weight, convert_func)]
|
|
return p
|
|
|
|
def model_state_dict(self, filter_prefix=None):
|
|
with self.use_ejected():
|
|
sd = self.model.state_dict()
|
|
keys = list(sd.keys())
|
|
if filter_prefix is not None:
|
|
for k in keys:
|
|
if not k.startswith(filter_prefix):
|
|
sd.pop(k)
|
|
return sd
|
|
|
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
|
if key not in self.patches:
|
|
return
|
|
|
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
|
inplace_update = self.weight_inplace_update or inplace_update
|
|
|
|
if key not in self.backup:
|
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
|
|
|
if device_to is not None:
|
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
|
else:
|
|
temp_weight = weight.to(torch.float32, copy=True)
|
|
if convert_func is not None:
|
|
temp_weight = convert_func(temp_weight, inplace=True)
|
|
|
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
|
if set_func is None:
|
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
|
if inplace_update:
|
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
|
else:
|
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
|
else:
|
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
|
|
|
def pin_weight_to_device(self, key):
|
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
|
if comfy.model_management.pin_memory(weight):
|
|
self.pinned.add(key)
|
|
|
|
def unpin_weight(self, key):
|
|
if key in self.pinned:
|
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
|
comfy.model_management.unpin_memory(weight)
|
|
self.pinned.remove(key)
|
|
|
|
def unpin_all_weights(self):
|
|
for key in list(self.pinned):
|
|
self.unpin_weight(key)
|
|
|
|
    def _load_list(self):
        loading = []
        for n, m in self.model.named_modules():
            params = []
            skip = False
            for name, param in m.named_parameters(recurse=False):
                params.append(name)
            for name, param in m.named_parameters(recurse=True):
                if name not in params:
                    skip = True # skip random weights in non leaf modules
                    break
            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                module_mem = comfy.model_management.module_size(m)
                module_offload_mem = module_mem
                if hasattr(m, "comfy_cast_weights"):
                    weight_key = "{}.weight".format(n)
                    bias_key = "{}.bias".format(n)
                    if weight_key in self.patches:
                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
                    if bias_key in self.patches:
                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
                loading.append((module_offload_mem, module_mem, n, m, params))
        return loading

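    # Note (added for clarity, not part of the original source): _load_list() puts
    # module_offload_mem first in each tuple, so loading.sort(reverse=True) in
    # load() below orders modules primarily by their offload expense, as described
    # in the commit message above.
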
    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        with self.use_ejected():
            self.unpatch_hooks()
            mem_counter = 0
            patch_counter = 0
            lowvram_counter = 0
            lowvram_mem_counter = 0
            loading = self._load_list()

            load_completely = []
            offloaded = []
            offload_buffer = 0
            loading.sort(reverse=True)
            for x in loading:
                module_offload_mem, module_mem, n, m, params = x

                lowvram_weight = False

                potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
                lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory

                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)

                if not full_load and hasattr(m, "comfy_cast_weights"):
                    if not lowvram_fits:
                        offload_buffer = potential_offload
                        lowvram_weight = True
                        lowvram_counter += 1
                        lowvram_mem_counter += module_mem
                        if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                            continue

cast_weight = self.force_cast_weights
|
|
if lowvram_weight:
|
|
if hasattr(m, "comfy_cast_weights"):
|
|
m.weight_function = []
|
|
m.bias_function = []
|
|
|
|
if weight_key in self.patches:
|
|
if force_patch_weights:
|
|
self.patch_weight_to_device(weight_key)
|
|
else:
|
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
|
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
|
patch_counter += 1
|
|
if bias_key in self.patches:
|
|
if force_patch_weights:
|
|
self.patch_weight_to_device(bias_key)
|
|
else:
|
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
|
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
|
patch_counter += 1
|
|
|
|
cast_weight = True
|
|
offloaded.append((module_mem, n, m, params))
|
|
else:
|
|
if hasattr(m, "comfy_cast_weights"):
|
|
wipe_lowvram_weight(m)
|
|
|
|
if full_load or lowvram_fits:
|
|
mem_counter += module_mem
|
|
load_completely.append((module_mem, n, m, params))
|
|
else:
|
|
offload_buffer = potential_offload
|
|
|
|
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
|
m.comfy_cast_weights = True
|
|
|
|
if weight_key in self.weight_wrapper_patches:
|
|
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
|
|
|
|
if bias_key in self.weight_wrapper_patches:
|
|
m.bias_function.extend(self.weight_wrapper_patches[bias_key])
|
|
|
|
mem_counter += move_weight_functions(m, device_to)
|
|
|
|
load_completely.sort(reverse=True)
|
|
for x in load_completely:
|
|
n = x[1]
|
|
m = x[2]
|
|
params = x[3]
|
|
if hasattr(m, "comfy_patched_weights"):
|
|
if m.comfy_patched_weights == True:
|
|
continue
|
|
|
|
for param in params:
|
|
key = "{}.{}".format(n, param)
|
|
self.unpin_weight(key)
|
|
self.patch_weight_to_device(key, device_to=device_to)
|
|
|
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
|
m.comfy_patched_weights = True
|
|
|
|
for x in load_completely:
|
|
x[2].to(device_to)
|
|
|
|
for x in offloaded:
|
|
n = x[1]
|
|
params = x[3]
|
|
for param in params:
|
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
|
|
|
if lowvram_counter > 0:
|
|
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
|
self.model.model_lowvram = True
|
|
else:
|
|
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
|
self.model.model_lowvram = False
|
|
if full_load:
|
|
self.model.to(device_to)
|
|
mem_counter = self.model_size()
|
|
|
|
self.model.lowvram_patch_counter += patch_counter
|
|
self.model.device = device_to
|
|
self.model.model_loaded_weight_memory = mem_counter
|
|
self.model.model_offload_buffer_memory = offload_buffer
|
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
|
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
|
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
|
|
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
|
|
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
|
with self.use_ejected():
|
|
for k in self.object_patches:
|
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
|
if k not in self.object_patches_backup:
|
|
self.object_patches_backup[k] = old
|
|
|
|
if lowvram_model_memory == 0:
|
|
full_load = True
|
|
else:
|
|
full_load = False
|
|
|
|
if load_weights:
|
|
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
|
self.inject_model()
|
|
return self.model
|
|
|
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
|
self.eject_model()
|
|
if unpatch_weights:
|
|
self.unpatch_hooks()
|
|
self.unpin_all_weights()
|
|
if self.model.model_lowvram:
|
|
for m in self.model.modules():
|
|
move_weight_functions(m, device_to)
|
|
wipe_lowvram_weight(m)
|
|
|
|
self.model.model_lowvram = False
|
|
self.model.lowvram_patch_counter = 0
|
|
|
|
keys = list(self.backup.keys())
|
|
|
|
for k in keys:
|
|
bk = self.backup[k]
|
|
if bk.inplace_update:
|
|
comfy.utils.copy_to_param(self.model, k, bk.weight)
|
|
else:
|
|
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
|
|
|
self.model.current_weight_patches_uuid = None
|
|
self.backup.clear()
|
|
|
|
if device_to is not None:
|
|
self.model.to(device_to)
|
|
self.model.device = device_to
|
|
self.model.model_loaded_weight_memory = 0
|
|
self.model.model_offload_buffer_memory = 0
|
|
|
|
for m in self.model.modules():
|
|
if hasattr(m, "comfy_patched_weights"):
|
|
del m.comfy_patched_weights
|
|
|
|
keys = list(self.object_patches_backup.keys())
|
|
for k in keys:
|
|
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
|
|
|
self.object_patches_backup.clear()
|
|
|
|
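    # Note (added for clarity, not part of the original source): the stop condition
    # in the loop below frees memory_to_free plus however much the reserved offload
    # buffer has grown, since unloading a large module can enlarge the temp-buffer
    # reservation that must also be left available.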
    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
            unload_list.sort()
            offload_buffer = self.model.model_offload_buffer_memory

            for unload in unload_list:
                if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                    break
                module_offload_mem, module_mem, n, m, params = unload

                potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem

                lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
|
move_weight = True
|
|
for param in params:
|
|
key = "{}.{}".format(n, param)
|
|
bk = self.backup.get(key, None)
|
|
if bk is not None:
|
|
if not lowvram_possible:
|
|
move_weight = False
|
|
break
|
|
|
|
if not hooks_unpatched:
|
|
self.unpatch_hooks()
|
|
hooks_unpatched = True
|
|
|
|
if bk.inplace_update:
|
|
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
|
else:
|
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
|
self.backup.pop(key)
|
|
|
|
weight_key = "{}.weight".format(n)
|
|
bias_key = "{}.bias".format(n)
|
|
if move_weight:
|
|
cast_weight = self.force_cast_weights
|
|
m.to(device_to)
|
|
module_mem += move_weight_functions(m, device_to)
|
|
if lowvram_possible:
|
|
if weight_key in self.patches:
|
|
if force_patch_weights:
|
|
self.patch_weight_to_device(weight_key)
|
|
else:
|
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
|
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
|
patch_counter += 1
|
|
if bias_key in self.patches:
|
|
if force_patch_weights:
|
|
self.patch_weight_to_device(bias_key)
|
|
else:
|
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
|
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
|
patch_counter += 1
|
|
cast_weight = True
|
|
|
|
if cast_weight:
|
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
|
m.comfy_cast_weights = True
|
|
m.comfy_patched_weights = False
|
|
memory_freed += module_mem
|
|
offload_buffer = max(offload_buffer, potential_offload)
|
|
logging.debug("freed {}".format(n))
|
|
|
|
for param in params:
|
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
|
|
|
|
|
self.model.model_lowvram = True
|
|
self.model.lowvram_patch_counter += patch_counter
|
|
self.model.model_loaded_weight_memory -= memory_freed
|
|
self.model.model_offload_buffer_memory = offload_buffer
|
|
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
|
return memory_freed
|
|
|
|
    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
        with self.use_ejected(skip_and_inject_on_exit_only=True):
            unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
            # TODO: force_patch_weights should not unload + reload full model
            used = self.model.model_loaded_weight_memory
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
            if unpatch_weights:
                extra_memory += (used - self.model.model_loaded_weight_memory)

            self.patch_model(load_weights=False)
            if extra_memory < 0 and not unpatch_weights:
                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
                return 0
            full_load = False
            if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
                self.apply_hooks(self.forced_hooks, force_apply=True)
                return 0
            if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
                full_load = True
            current_used = self.model.model_loaded_weight_memory
            try:
                self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
            except Exception as e:
                self.detach()
                raise e

            return self.model.model_loaded_weight_memory - current_used

def detach(self, unpatch_all=True):
|
|
self.eject_model()
|
|
self.model_patches_to(self.offload_device)
|
|
if unpatch_all:
|
|
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
|
|
callback(self, unpatch_all)
|
|
return self.model
|
|
|
|
def current_loaded_device(self):
|
|
return self.model.device
|
|
|
|
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
|
logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
|
|
|
def cleanup(self):
|
|
self.clean_hooks()
|
|
if hasattr(self.model, "current_patcher"):
|
|
self.model.current_patcher = None
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
|
|
callback(self)
|
|
|
|
def add_callback(self, call_type: str, callback: Callable):
|
|
self.add_callback_with_key(call_type, None, callback)
|
|
|
|
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
|
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
|
c.append(callback)
|
|
|
|
def remove_callbacks_with_key(self, call_type: str, key: str):
|
|
c = self.callbacks.get(call_type, {})
|
|
if key in c:
|
|
c.pop(key)
|
|
|
|
def get_callbacks(self, call_type: str, key: str):
|
|
return self.callbacks.get(call_type, {}).get(key, [])
|
|
|
|
def get_all_callbacks(self, call_type: str):
|
|
c_list = []
|
|
for c in self.callbacks.get(call_type, {}).values():
|
|
c_list.extend(c)
|
|
return c_list
|
|
|
|
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
|
|
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
|
|
|
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
|
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
|
w.append(wrapper)
|
|
|
|
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
|
w = self.wrappers.get(wrapper_type, {})
|
|
if key in w:
|
|
w.pop(key)
|
|
|
|
def get_wrappers(self, wrapper_type: str, key: str):
|
|
return self.wrappers.get(wrapper_type, {}).get(key, [])
|
|
|
|
def get_all_wrappers(self, wrapper_type: str):
|
|
w_list = []
|
|
for w in self.wrappers.get(wrapper_type, {}).values():
|
|
w_list.extend(w)
|
|
return w_list
|
|
|
|
def set_attachments(self, key: str, attachment):
|
|
self.attachments[key] = attachment
|
|
|
|
def remove_attachments(self, key: str):
|
|
if key in self.attachments:
|
|
self.attachments.pop(key)
|
|
|
|
def get_attachment(self, key: str):
|
|
return self.attachments.get(key, None)
|
|
|
|
def set_injections(self, key: str, injections: list[PatcherInjection]):
|
|
self.injections[key] = injections
|
|
|
|
def remove_injections(self, key: str):
|
|
if key in self.injections:
|
|
self.injections.pop(key)
|
|
|
|
def get_injections(self, key: str):
|
|
return self.injections.get(key, None)
|
|
|
|
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
|
self.additional_models[key] = models
|
|
|
|
def remove_additional_models(self, key: str):
|
|
if key in self.additional_models:
|
|
self.additional_models.pop(key)
|
|
|
|
def get_additional_models_with_key(self, key: str):
|
|
return self.additional_models.get(key, [])
|
|
|
|
def get_additional_models(self):
|
|
all_models = []
|
|
for models in self.additional_models.values():
|
|
all_models.extend(models)
|
|
return all_models
|
|
|
|
def get_nested_additional_models(self):
|
|
def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
|
|
'''Make sure circular references do not cause infinite recursion.'''
|
|
next_models = []
|
|
for model in prev_models:
|
|
candidates = model.get_additional_models()
|
|
for c in candidates:
|
|
if c not in cache_set:
|
|
next_models.append(c)
|
|
cache_set.add(c)
|
|
if len(next_models) == 0:
|
|
return prev_models
|
|
return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
|
|
|
|
all_models = self.get_additional_models()
|
|
models_set = set(all_models)
|
|
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
|
|
return real_all_models
|
|
|
|
def use_ejected(self, skip_and_inject_on_exit_only=False):
|
|
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
|
|
|
|
def inject_model(self):
|
|
if self.is_injected or self.skip_injection:
|
|
return
|
|
for injections in self.injections.values():
|
|
for inj in injections:
|
|
inj.inject(self)
|
|
self.is_injected = True
|
|
if self.is_injected:
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
|
|
callback(self)
|
|
|
|
def eject_model(self):
|
|
if not self.is_injected:
|
|
return
|
|
for injections in self.injections.values():
|
|
for inj in injections:
|
|
inj.eject(self)
|
|
self.is_injected = False
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
|
|
callback(self)
|
|
|
|
def pre_run(self):
|
|
if hasattr(self.model, "current_patcher"):
|
|
self.model.current_patcher = self
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
|
callback(self)
|
|
|
|
def prepare_state(self, timestep):
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
|
callback(self, timestep)
|
|
|
|
def restore_hook_patches(self):
|
|
if self.hook_patches_backup is not None:
|
|
self.hook_patches = self.hook_patches_backup
|
|
self.hook_patches_backup = None
|
|
|
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
|
self.hook_mode = hook_mode
|
|
|
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
|
curr_t = t[0]
|
|
reset_current_hooks = False
|
|
transformer_options = model_options.get("transformer_options", {})
|
|
for hook in hook_group.hooks:
|
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
|
# this will cause the weights to be recalculated when sampling
|
|
if changed:
|
|
# reset current_hooks if contains hook that changed
|
|
if self.current_hooks is not None:
|
|
for current_hook in self.current_hooks.hooks:
|
|
if current_hook == hook:
|
|
reset_current_hooks = True
|
|
break
|
|
for cached_group in list(self.cached_hook_patches.keys()):
|
|
if cached_group.contains(hook):
|
|
self.cached_hook_patches.pop(cached_group)
|
|
if reset_current_hooks:
|
|
self.patch_hooks(None)
|
|
|
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
|
registered: comfy.hooks.HookGroup = None):
|
|
self.restore_hook_patches()
|
|
if registered is None:
|
|
registered = comfy.hooks.HookGroup()
|
|
# handle WeightHooks
|
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
|
if hook.hook_ref not in self.hook_patches:
|
|
weight_hooks_to_register.append(hook)
|
|
else:
|
|
registered.add(hook)
|
|
if len(weight_hooks_to_register) > 0:
|
|
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
|
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
|
for hook in weight_hooks_to_register:
|
|
hook.add_hook_patches(self, model_options, target_dict, registered)
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
|
callback(self, hooks, target_dict, model_options, registered)
|
|
return registered
|
|
|
|
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
|
with self.use_ejected():
|
|
# NOTE: this mirrors behavior of add_patches func
|
|
current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
|
|
p = set()
|
|
model_sd = self.model.state_dict()
|
|
for k in patches:
|
|
offset = None
|
|
function = None
|
|
if isinstance(k, str):
|
|
key = k
|
|
else:
|
|
offset = k[1]
|
|
key = k[0]
|
|
if len(k) > 2:
|
|
function = k[2]
|
|
|
|
if key in model_sd:
|
|
p.add(k)
|
|
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
|
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
|
|
current_hook_patches[key] = current_patches
|
|
self.hook_patches[hook.hook_ref] = current_hook_patches
|
|
# since should care about these patches too to determine if same model, reroll patches_uuid
|
|
self.patches_uuid = uuid.uuid4()
|
|
return list(p)
|
|
|
|
def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
|
|
# combined_patches will contain weights of all relevant hooks, per key
|
|
combined_patches = {}
|
|
if hooks is not None:
|
|
for hook in hooks.hooks:
|
|
hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
|
|
for key in hook_patches.keys():
|
|
current_patches: list[tuple] = combined_patches.get(key, [])
|
|
if math.isclose(hook.strength, 1.0):
|
|
current_patches.extend(hook_patches[key])
|
|
else:
|
|
# patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
|
|
for patch in hook_patches[key]:
|
|
new_patch = list(patch)
|
|
new_patch[0] *= hook.strength
|
|
current_patches.append(tuple(new_patch))
|
|
combined_patches[key] = current_patches
|
|
return combined_patches
|
|
|
|
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
|
# TODO: return transformer_options dict with any additions from hooks
|
|
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
|
self.patch_hooks(hooks=hooks)
|
|
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
|
callback(self, hooks)
|
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
|
|
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
|
with self.use_ejected():
|
|
if hooks is not None:
|
|
model_sd_keys = list(self.model_state_dict().keys())
|
|
memory_counter = None
|
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
|
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
|
|
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
|
|
minimum=comfy.model_management.minimum_inference_memory()*2)
|
|
# if have cached weights for hooks, use it
|
|
cached_weights = self.cached_hook_patches.get(hooks, None)
|
|
if cached_weights is not None:
|
|
model_sd_keys_set = set(model_sd_keys)
|
|
for key in cached_weights:
|
|
if key not in model_sd_keys:
|
|
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
|
continue
|
|
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
|
model_sd_keys_set.remove(key)
|
|
self.unpatch_hooks(model_sd_keys_set)
|
|
else:
|
|
self.unpatch_hooks()
|
|
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
|
original_weights = None
|
|
if len(relevant_patches) > 0:
|
|
original_weights = self.get_key_patches()
|
|
for key in relevant_patches:
|
|
if key not in model_sd_keys:
|
|
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
|
|
continue
|
|
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
|
memory_counter=memory_counter)
|
|
else:
|
|
self.unpatch_hooks()
|
|
self.current_hooks = hooks
|
|
|
|
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
|
if key not in self.hook_backup:
|
|
weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
|
|
target_device = self.offload_device
|
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
|
used = memory_counter.use(weight)
|
|
if used:
|
|
target_device = weight.device
|
|
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
|
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
|
|
|
|
def clear_cached_hook_weights(self):
|
|
self.cached_hook_patches.clear()
|
|
self.patch_hooks(None)
|
|
|
|
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
|
if key not in combined_patches:
|
|
return
|
|
|
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
|
weight: torch.Tensor
|
|
if key not in self.hook_backup:
|
|
target_device = self.offload_device
|
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
|
used = memory_counter.use(weight)
|
|
if used:
|
|
target_device = weight.device
|
|
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
|
# TODO: properly handle LowVramPatch, if it ends up an issue
|
|
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
|
|
if convert_func is not None:
|
|
temp_weight = convert_func(temp_weight, inplace=True)
|
|
|
|
out_weight = comfy.lora.calculate_weight(combined_patches[key],
|
|
temp_weight,
|
|
key, original_weights=original_weights)
|
|
del original_weights[key]
|
|
if set_func is None:
|
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
|
else:
|
|
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
|
# TODO: disable caching if not enough system RAM to do so
|
|
target_device = self.offload_device
|
|
used = memory_counter.use(weight)
|
|
if used:
|
|
target_device = weight.device
|
|
self.cached_hook_patches.setdefault(hooks, {})
|
|
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
|
|
del temp_weight
|
|
del out_weight
|
|
del weight
|
|
|
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
|
with self.use_ejected():
|
|
if len(self.hook_backup) == 0:
|
|
self.current_hooks = None
|
|
return
|
|
keys = list(self.hook_backup.keys())
|
|
if whitelist_keys_set:
|
|
for k in keys:
|
|
if k in whitelist_keys_set:
|
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
|
self.hook_backup.pop(k)
|
|
else:
|
|
for k in keys:
|
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
|
|
|
self.hook_backup.clear()
|
|
self.current_hooks = None
|
|
|
|
def clean_hooks(self):
|
|
self.unpatch_hooks()
|
|
self.clear_cached_hook_weights()
|
|
|
|
def __del__(self):
|
|
self.unpin_all_weights()
|
|
self.detach(unpatch_all=False)
|
|
|