Merge e65e642fbe9f19e10ee44f322b2f971c0b9b525f into 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf

This commit is contained in:
rattus 2025-12-07 07:51:04 -05:00 committed by GitHub
commit 36f621f7e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 92 additions and 37 deletions

View File

@ -305,6 +305,7 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
comfy.model_management.free_ram(state_dict=to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

View File

@ -445,6 +445,20 @@ try:
except:
logging.warning("Could not pick default device.")
current_ram_listeners = set()
def register_ram_listener(listener):
current_ram_listeners.add(listener)
def unregister_ram_listener(listener):
current_ram_listeners.discard(listener)
def free_ram(extra_ram=0, state_dict={}):
for tensor in state_dict.values():
if isinstance(tensor, torch.Tensor):
extra_ram += tensor.numel() * tensor.element_size()
for listener in current_ram_listeners:
listener.free_ram(extra_ram)
current_loaded_models = []
@ -521,12 +535,18 @@ class LoadedModel:
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if self.model is None:
return True
logging.debug(f"Unloading {self.model.model.__class__.__name__}")
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
offload_modules(modules_to_offload, self.model.offload_device)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
if self.model is not None:
modules_to_offload = self.model.detach(unpatch_weights)
offload_modules(modules_to_offload, self.model.offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@ -543,7 +563,7 @@ class LoadedModel:
self._patcher_finalizer.detach()
def is_dead(self):
return self.real_model() is not None and self.model is None
return self.real_model is not None and self.real_model() is not None and self.model is None
def use_more_memory(extra_memory, loaded_models, device):
@ -578,6 +598,13 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def offload_modules(modules, offload_device):
for module in modules:
if module() is None:
continue
module().to(offload_device)
free_ram()
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
@ -588,23 +615,25 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
shift_model = x[-1]
i = x[-2]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
if shift_model.model_unload(memory_to_free):
unloaded_model.append((i, shift_model))
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
for i, shift_model in sorted(unloaded_model, reverse=True):
unloaded_models.append(shift_model)
if shift_model in current_loaded_models:
current_loaded_models.remove(shift_model)
if len(unloaded_model) > 0:
soft_empty_cache()
@ -739,7 +768,7 @@ def cleanup_models_gc():
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete
for i in to_delete:

View File

@ -24,6 +24,7 @@ import inspect
import logging
import math
import uuid
import weakref
from typing import Callable, Optional
import torch
@ -817,6 +818,7 @@ class ModelPatcher:
def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
modules_to_move = []
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
@ -841,7 +843,8 @@ class ModelPatcher:
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
modules_to_move.append(weakref.ref(self.model))
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@ -855,12 +858,14 @@ class ModelPatcher:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear()
return modules_to_move
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
modules_to_move = []
unload_list = self._load_list()
unload_list.sort()
@ -901,7 +906,7 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
modules_to_move.append(weakref.ref(m))
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@ -939,20 +944,22 @@ class ModelPatcher:
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
return memory_freed, modules_to_move
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
_, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@ -964,7 +971,7 @@ class ModelPatcher:
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
comfy.model_management.offload_modules(self.detach(), self.offload_device)
raise e
return self.model.model_loaded_weight_memory - current_used
@ -972,11 +979,12 @@ class ModelPatcher:
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
modules_to_offload = []
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
return modules_to_offload
def current_loaded_device(self):
return self.model.device

View File

@ -284,6 +284,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
comfy.model_management.free_ram(state_dict=sd)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@ -651,6 +652,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
comfy.model_management.free_ram(state_dict=sd)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@ -961,6 +963,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.model_management.free_ram(state_dict=model_data)
model.load_state_dict(model_data)
return StyleModel(model)

View File

@ -193,7 +193,7 @@ class BasicCache:
self._clean_cache()
self._clean_subcaches()
def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass
def _set_immediate(self, node_id, value):
@ -284,7 +284,7 @@ class NullCache:
def clean_unused(self):
pass
def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass
def get(self, node_id):
@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
def __init__(self, key_class, min_headroom=4.0):
super().__init__(key_class, 0)
self.timestamps = {}
self.min_headroom = min_headroom
def clean_unused(self):
self._clean_subcaches()
@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
def _build_clean_list(self):
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, (_, outputs), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
return clean_list
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
def free_ram(self, extra_ram=0):
headroom_target = self.min_headroom + (extra_ram / (1024**3))
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > headroom_target:
return
gc.collect()
if _ram_gb() > headroom_target:
return
clean_list = self._build_clean_list()
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()

View File

@ -107,7 +107,7 @@ class CacheSet:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
cache_ram = cache_args.get("ram", 4.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
@ -129,7 +129,7 @@ class CacheSet:
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@ -618,13 +618,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_args=None):
self.caches = None
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
if self.caches is not None:
for cache in self.caches.all:
comfy.model_management.unregister_ram_listener(cache)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
for cache in self.caches.all:
comfy.model_management.register_ram_listener(cache)
self.status_messages = []
self.success = True
@ -722,7 +730,7 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
self.caches.outputs.free_ram()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)