mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 05:54:24 +08:00)
commit 36f621f7e2 — Merge e65e642fbe9f19e10ee44f322b2f971c0b9b525f into 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf
@@ -305,6 +305,7 @@ class BaseModel(torch.nn.Module):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)
 
         to_load = self.model_config.process_unet_state_dict(to_load)
+        comfy.model_management.free_ram(state_dict=to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))
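The added free_ram(state_dict=to_load) call asks the registered RAM listeners to make room before the UNet state dict is copied into the diffusion model, sizing the request from the tensors about to be loaded. A minimal sketch of that accounting (state_dict_bytes is an illustrative name, not part of the change):

import torch

# Mirror the numel() * element_size() accounting free_ram() uses to turn
# a state dict into a byte count before notifying listeners.
def state_dict_bytes(sd):
    return sum(t.numel() * t.element_size()
               for t in sd.values() if isinstance(t, torch.Tensor))

example = {"weight": torch.zeros(1024, 1024, dtype=torch.float16)}
print(state_dict_bytes(example))  # 1024 * 1024 * 2 bytes = 2 MiB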
@@ -445,6 +445,20 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
+current_ram_listeners = set()
+
+def register_ram_listener(listener):
+    current_ram_listeners.add(listener)
+
+def unregister_ram_listener(listener):
+    current_ram_listeners.discard(listener)
+
+def free_ram(extra_ram=0, state_dict={}):
+    for tensor in state_dict.values():
+        if isinstance(tensor, torch.Tensor):
+            extra_ram += tensor.numel() * tensor.element_size()
+    for listener in current_ram_listeners:
+        listener.free_ram(extra_ram)
+
 current_loaded_models = []
 
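As defined above, a RAM listener is just an object with a free_ram(extra_ram) method; register_ram_listener()/unregister_ram_listener() manage a module-level set, and free_ram() fans the request out to every listener. A minimal sketch of the protocol, assuming a running ComfyUI environment (LoggingListener is illustrative):

import comfy.model_management

class LoggingListener:
    # Anything with a free_ram(extra_ram) method can be registered;
    # extra_ram is the number of bytes the caller is about to allocate.
    def free_ram(self, extra_ram):
        print(f"asked to make room for {extra_ram / (1024**3):.2f} GiB")

listener = LoggingListener()
comfy.model_management.register_ram_listener(listener)    # set.add(): duplicates are harmless
comfy.model_management.free_ram(extra_ram=2 * 1024**3)     # notifies every registered listener
comfy.model_management.unregister_ram_listener(listener)   # set.discard(): safe if already gone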
@@ -521,12 +535,18 @@ class LoadedModel:
         return False
 
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if self.model is None:
+            return True
+        logging.debug(f"Unloading {self.model.model.__class__.__name__}")
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                offload_modules(modules_to_offload, self.model.offload_device)
                 if freed >= memory_to_free:
                     return False
-        self.model.detach(unpatch_weights)
+        if self.model is not None:
+            modules_to_offload = self.model.detach(unpatch_weights)
+            offload_modules(modules_to_offload, self.model.offload_device)
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -543,7 +563,7 @@ class LoadedModel:
             self._patcher_finalizer.detach()
 
     def is_dead(self):
-        return self.real_model() is not None and self.model is None
+        return self.real_model is not None and self.real_model() is not None and self.model is None
 
 
 def use_more_memory(extra_memory, loaded_models, device):
@@ -578,6 +598,13 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
+def offload_modules(modules, offload_device):
+    for module in modules:
+        if module() is None:
+            continue
+        module().to(offload_device)
+    free_ram()
+
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     unloaded_model = []
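offload_modules() expects weak references: each entry is dereferenced with module() and skipped when the referent has already been garbage collected, and free_ram() runs once after the moves. A sketch of the calling pattern (the small nn.Linear is illustrative):

import weakref
import torch
import comfy.model_management

mod = torch.nn.Linear(8, 8)
refs = [weakref.ref(mod)]  # callers hand over weakref.ref objects, not the modules themselves
comfy.model_management.offload_modules(refs, torch.device("cpu"))  # dead refs skipped, then free_ram()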
@@ -588,23 +615,25 @@ def free_memory(memory_required, device, keep_loaded=[]):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
-                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
                 shift_model.currently_used = False
 
     for x in sorted(can_unload):
-        i = x[-1]
+        shift_model = x[-1]
+        i = x[-2]
         memory_to_free = None
         if not DISABLE_SMART_MEMORY:
            free_mem = get_free_memory(device)
            if free_mem > memory_required:
                break
            memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
-            unloaded_model.append(i)
+        if shift_model.model_unload(memory_to_free):
+            unloaded_model.append((i, shift_model))
 
-    for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+    for i, shift_model in sorted(unloaded_model, reverse=True):
+        unloaded_models.append(shift_model)
+        if shift_model in current_loaded_models:
+            current_loaded_models.remove(shift_model)
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
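free_memory() now tracks the LoadedModel objects themselves instead of bare list indices and removes them with a membership-checked remove(), which stays correct even if current_loaded_models is mutated while models are being unloaded (plausible now that free_ram listeners can run during an unload). An illustrative sketch of the hazard the identity-based removal avoids (plain strings stand in for LoadedModel entries):

# Popping by index breaks as soon as the list shifts underneath the loop;
# removal by identity with a membership check does not.
models = ["A", "B", "C"]
to_drop = ["B", "C"]
models.remove("A")        # the list was mutated by something else meanwhile
for m in to_drop:
    if m in models:       # guard mirrors: if shift_model in current_loaded_models
        models.remove(m)
print(models)             # -> []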
@@ -739,7 +768,7 @@ def cleanup_models_gc():
 def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
-        if current_loaded_models[i].real_model() is None:
+        if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
             to_delete = [i] + to_delete
 
     for i in to_delete:
@@ -24,6 +24,7 @@ import inspect
 import logging
 import math
 import uuid
+import weakref
 from typing import Callable, Optional
 
 import torch
@@ -817,6 +818,7 @@ class ModelPatcher:
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        modules_to_move = []
         if unpatch_weights:
             self.unpatch_hooks()
             self.unpin_all_weights()
@@ -841,7 +843,8 @@ class ModelPatcher:
             self.backup.clear()
 
         if device_to is not None:
-            self.model.to(device_to)
+            modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
+            modules_to_move.append(weakref.ref(self.model))
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
             self.model.model_offload_buffer_memory = 0
@@ -855,12 +858,14 @@ class ModelPatcher:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
 
         self.object_patches_backup.clear()
+        return modules_to_move
 
     def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
             patch_counter = 0
+            modules_to_move = []
             unload_list = self._load_list()
             unload_list.sort()
 
@@ -901,7 +906,7 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    m.to(device_to)
+                    modules_to_move.append(weakref.ref(m))
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
@@ -939,20 +944,22 @@ class ModelPatcher:
             self.model.model_loaded_weight_memory -= memory_freed
             self.model.model_offload_buffer_memory = offload_buffer
         logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
-        return memory_freed
+        return memory_freed, modules_to_move
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         with self.use_ejected(skip_and_inject_on_exit_only=True):
             unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
             # TODO: force_patch_weights should not unload + reload full model
             used = self.model.model_loaded_weight_memory
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)
 
             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
                 return 0
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
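partially_unload() now reports what it did rather than moving anything itself: it returns a (memory_freed, modules_to_move) pair, and partially_load() forwards the weakref list to comfy.model_management.offload_modules(), which performs the device moves and triggers free_ram(). A sketch of a caller consuming the new return value (patcher is assumed to be an existing ModelPatcher):

# Hypothetical caller: try to free about 1 GiB from a loaded patcher.
freed, modules_to_move = patcher.partially_unload(
    patcher.offload_device, memory_to_free=1 * 1024**3)
comfy.model_management.offload_modules(modules_to_move, patcher.offload_device)
print(f"freed {freed / (1024**2):.1f} MB")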
@@ -964,7 +971,7 @@ class ModelPatcher:
             try:
                 self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
             except Exception as e:
-                self.detach()
+                comfy.model_management.offload_modules(self.detach(), self.offload_device)
                 raise e
 
             return self.model.model_loaded_weight_memory - current_used
@@ -972,11 +979,12 @@
     def detach(self, unpatch_all=True):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        modules_to_offload = []
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
-        return self.model
+        return modules_to_offload
 
     def current_loaded_device(self):
         return self.model.device
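detach() changes its return value as well: it used to return self.model and now returns the list of module weakrefs produced by unpatch_model(), so the caller decides when the offload happens. The updated idiom, sketched with an assumed patcher:

# detach() no longer returns the model; it returns weakrefs to offload.
modules_to_offload = patcher.detach(unpatch_all=True)
comfy.model_management.offload_modules(modules_to_offload, patcher.offload_device)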
@@ -284,6 +284,7 @@ class CLIP:
 
     def load_sd(self, sd, full_model=False):
         if full_model:
+            comfy.model_management.free_ram(state_dict=sd)
             return self.cond_stage_model.load_state_dict(sd, strict=False)
         else:
             return self.cond_stage_model.load_sd(sd)
@@ -651,6 +652,7 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
 
+        comfy.model_management.free_ram(state_dict=sd)
         m, u = self.first_stage_model.load_state_dict(sd, strict=False)
         if len(m) > 0:
             logging.warning("Missing VAE keys {}".format(m))
@@ -961,6 +963,7 @@ def load_style_model(ckpt_path):
         model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
+    comfy.model_management.free_ram(state_dict=model_data)
     model.load_state_dict(model_data)
     return StyleModel(model)
 
@@ -193,7 +193,7 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()
 
-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
         pass
 
     def _set_immediate(self, node_id, value):
@@ -284,7 +284,7 @@ class NullCache:
     def clean_unused(self):
         pass
 
-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
         pass
 
     def get(self, node_id):
@@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
 
 class RAMPressureCache(LRUCache):
 
-    def __init__(self, key_class):
+    def __init__(self, key_class, min_headroom=4.0):
         super().__init__(key_class, 0)
         self.timestamps = {}
+        self.min_headroom = min_headroom
 
     def clean_unused(self):
         self._clean_subcaches()
@@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
         return super().get(node_id)
 
-    def poll(self, ram_headroom):
-        def _ram_gb():
-            return psutil.virtual_memory().available / (1024**3)
-
-        if _ram_gb() > ram_headroom:
-            return
-        gc.collect()
-        if _ram_gb() > ram_headroom:
-            return
-
+    def _build_clean_list(self):
         clean_list = []
 
-        for key, (outputs, _), in self.cache.items():
+        for key, (_, outputs), in self.cache.items():
             oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
 
             ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
             #In the case where we have no information on the node ram usage at all,
             #break OOM score ties on the last touch timestamp (pure LRU)
             bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+        return clean_list
 
-        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+    def free_ram(self, extra_ram=0):
+        headroom_target = self.min_headroom + (extra_ram / (1024**3))
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > headroom_target:
+            return
+        gc.collect()
+        if _ram_gb() > headroom_target:
+            return
+
+        clean_list = self._build_clean_list()
+
+        while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
             _, _, key = clean_list.pop()
             del self.cache[key]
             gc.collect()
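RAMPressureCache no longer polls against a fixed ram_headroom argument; free_ram() derives its target from the configured minimum headroom plus whatever the caller is about to allocate, and evicts entries in OOM-score order until available RAM climbs back above the target scaled by RAM_CACHE_HYSTERESIS. The arithmetic, spelled out with illustrative numbers:

# headroom_target grows with the pending allocation:
min_headroom = 4.0               # GB, the new default wired through CacheSet
extra_ram = 6 * 1024**3          # e.g. a 6 GB state dict about to be loaded
headroom_target = min_headroom + extra_ram / (1024**3)
print(headroom_target)           # -> 10.0 GB must be available, or eviction begins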
execution.py (14 changed lines)
@@ -107,7 +107,7 @@ class CacheSet:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.RAM_PRESSURE:
-            cache_ram = cache_args.get("ram", 16.0)
+            cache_ram = cache_args.get("ram", 4.0)
             self.init_ram_cache(cache_ram)
             logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
@@ -129,7 +129,7 @@ class CacheSet:
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_ram_cache(self, min_headroom):
-        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_null_cache(self):
@@ -618,13 +618,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
 
 class PromptExecutor:
     def __init__(self, server, cache_type=False, cache_args=None):
+        self.caches = None
         self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
+        if self.caches is not None:
+            for cache in self.caches.all:
+                comfy.model_management.unregister_ram_listener(cache)
+
         self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
+
+        for cache in self.caches.all:
+            comfy.model_management.register_ram_listener(cache)
         self.status_messages = []
         self.success = True
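PromptExecutor.reset() unregisters the caches of the previous CacheSet before registering the new ones, so free_ram() notifications always reach live caches and the listener set never pins a stale CacheSet in memory. Because registration is backed by a set (add()/discard()), repeated calls are harmless, as this sketch with an assumed executor shows:

cache = executor.caches.outputs
comfy.model_management.register_ram_listener(cache)
comfy.model_management.register_ram_listener(cache)    # no duplicate entry
comfy.model_management.unregister_ram_listener(cache)
comfy.model_management.unregister_ram_listener(cache)   # discard(): no KeyError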
@@ -722,7 +730,7 @@ class PromptExecutor:
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
-                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+                    self.caches.outputs.free_ram()
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)