From 68053b1180a31bcca9cab18fb042e8d3117a5c4c Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 14 Nov 2025 15:31:57 +1000 Subject: [PATCH 1/7] caching: build headroom into the RAM cache move the headroom logic into the RAM cache to make this a little easier to call to "free me some RAM". Rename the API to free_ram(). Split off the clean_list creation to a completely separate function to avoid any stray strong reference to the content-to-be-freed on the stack. --- comfy_execution/caching.py | 36 +++++++++++++++++++++--------------- execution.py | 6 +++--- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..43f882469 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -193,7 +193,7 @@ class BasicCache: self._clean_cache() self._clean_subcaches() - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def _set_immediate(self, node_id, value): @@ -284,7 +284,7 @@ class NullCache: def clean_unused(self): pass - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def get(self, node_id): @@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): + def __init__(self, key_class, min_headroom=4.0): super().__init__(key_class, 0) self.timestamps = {} + self.min_headroom = min_headroom def clean_unused(self): self._clean_subcaches() @@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() return super().get(node_id) - def poll(self, ram_headroom): - def _ram_gb(): - return psutil.virtual_memory().available / (1024**3) - - if _ram_gb() > ram_headroom: - return - gc.collect() - if _ram_gb() > ram_headroom: - return - + def _build_clean_list(self): clean_list = [] - for key, (outputs, _), in self.cache.items(): + for key, (_, outputs), in self.cache.items(): oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache): #In the case where we have no information on the node ram usage at all, #break OOM score ties on the last touch timestamp (pure LRU) bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + return clean_list - while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + def free_ram(self, extra_ram=0): + headroom_target = self.min_headroom + (extra_ram / (1024**3)) + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > headroom_target: + return + gc.collect() + if _ram_gb() > headroom_target: + return + + clean_list = self._build_clean_list() + + while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list: _, _, key = clean_list.pop() del self.cache[key] gc.collect() diff --git a/execution.py b/execution.py index 17c77beab..44e3bb65c 100644 --- a/execution.py +++ b/execution.py @@ -107,7 +107,7 @@ class CacheSet: self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.RAM_PRESSURE: - cache_ram = cache_args.get("ram", 16.0) + cache_ram = cache_args.get("ram", 4.0) self.init_ram_cache(cache_ram) logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: @@ -129,7 +129,7 @@ class CacheSet: self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs 
= RAMPressureCache(CacheKeySetInputSignature, min_headroom) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -717,7 +717,7 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + self.caches.outputs.free_ram() else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) From 62a26225915c8d399c1fcdee61a34af8636f4bf2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 18 Nov 2025 09:15:44 +1000 Subject: [PATCH 2/7] mm: Add free_ram() Add the free_ram() API and a means to install implementations of the freer (I.E. the RAM cache). --- comfy/model_management.py | 14 ++++++++++++++ execution.py | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe..6222c19ae 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -445,6 +445,20 @@ try: except: logging.warning("Could not pick default device.") +current_ram_listeners = set() + +def register_ram_listener(listener): + current_ram_listeners.add(listener) + +def unregister_ram_listener(listener): + current_ram_listeners.discard(listener) + +def free_ram(extra_ram=0, state_dict={}): + for tensor in state_dict.values(): + if isinstance(tensor, torch.Tensor): + extra_ram += tensor.numel() * tensor.element_size() + for listener in current_ram_listeners: + listener.free_ram(extra_ram) current_loaded_models = [] diff --git a/execution.py b/execution.py index 44e3bb65c..dd5fc8baf 100644 --- a/execution.py +++ b/execution.py @@ -613,13 +613,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class PromptExecutor: def __init__(self, server, cache_type=False, cache_args=None): + self.caches = None self.cache_args = cache_args self.cache_type = cache_type self.server = server self.reset() def reset(self): + if self.caches is not None: + for cache in self.caches.all: + comfy.model_management.unregister_ram_listener(cache) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + + for cache in self.caches.all: + comfy.model_management.register_ram_listener(cache) self.status_messages = [] self.success = True From 4a83a9bc0e6da93af2fe75457768fc1e2a0238ad Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 18 Nov 2025 09:17:10 +1000 Subject: [PATCH 3/7] sd: Free RAM on main model load --- comfy/model_base.py | 1 + comfy/sd.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285e..76dc3e370 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -304,6 +304,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) + comfy.model_management.free_ram(state_dict=to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5..bfa63debc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -260,6 +260,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: + comfy.model_management.free_ram(state_dict=sd) return self.cond_stage_model.load_state_dict(sd, strict=False) else: return self.cond_stage_model.load_sd(sd) @@ -625,6 +626,7 @@ class VAE: self.first_stage_model = 
AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
 
+        comfy.model_management.free_ram(state_dict=sd)
         m, u = self.first_stage_model.load_state_dict(sd, strict=False)
         if len(m) > 0:
             logging.warning("Missing VAE keys {}".format(m))
@@ -933,6 +935,7 @@ def load_style_model(ckpt_path):
         model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
+    comfy.model_management.free_ram(state_dict=model_data)
     model.load_state_dict(model_data)
     return StyleModel(model)
 

From 7af5bf49e4bef048bc86af9717713b280ce6532d Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 18 Nov 2025 10:10:16 +1000
Subject: [PATCH 4/7] mm: make garbage collector null-safe on real_model

Currently this hard-assumes that the caller of model_unload will keep
current_loaded_models in sync. With RAMPressureCache it is possible for the
garbage collector to run in the middle of the model-free process, which can
split these two steps.
---
 comfy/model_management.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6222c19ae..1c5796410 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -557,7 +557,7 @@ class LoadedModel:
             self._patcher_finalizer.detach()
 
     def is_dead(self):
-        return self.real_model() is not None and self.model is None
+        return self.real_model is not None and self.real_model() is not None and self.model is None
 
 
 def use_more_memory(extra_memory, loaded_models, device):
@@ -753,7 +753,7 @@ def cleanup_models_gc():
 def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
-        if current_loaded_models[i].real_model() is None:
+        if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
             to_delete = [i] + to_delete
 
     for i in to_delete:

From 07d7cd9618eaa202be5c26fa7bcf485ecee0713d Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 18 Nov 2025 09:32:28 +1000
Subject: [PATCH 5/7] mm: don't use a list of indexes for the unload work list

This is currently put together as a list of indexes, assuming
current_loaded_models doesn't change. However, we might need to purge a model
as part of the offload process, which means this list can change in the
middle of the freeing process. Handle this by taking independent refs to the
LoadedModel objects and doing safe by-value deletion from
current_loaded_models.
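To make the failure mode concrete, here is a small self-contained sketch of
index-based versus by-value removal (the Model class and the lists are
illustrative stand-ins, not ComfyUI's actual LoadedModel bookkeeping):

# Indexes recorded up front go stale as soon as anything else removes an
# entry; holding the objects and removing them by value stays correct.
class Model:
    def __init__(self, name):
        self.name = name

loaded = [Model("unet"), Model("clip"), Model("vae")]

# Fragile: positions 0 and 2 only mean "unet" and "vae" while the list is
# unchanged -- one concurrent removal shifts everything after it.
# Safer: take independent references and delete by value, checking membership
# first because a concurrent purge may already have dropped the entry.
to_unload = [loaded[0], loaded[2]]
for model in to_unload:
    if model in loaded:
        loaded.remove(model)

print([m.name for m in loaded])  # -> ['clip']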
---
 comfy/model_management.py | 17 ++++++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 1c5796410..18a700905 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -602,23 +602,26 @@ def free_memory(memory_required, device, keep_loaded=[]):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
-                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
                 shift_model.currently_used = False
 
     for x in sorted(can_unload):
-        i = x[-1]
+        shift_model = x[-1]
+        i = x[-2]
         memory_to_free = None
         if not DISABLE_SMART_MEMORY:
             free_mem = get_free_memory(device)
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
-            unloaded_model.append(i)
+        logging.info(f"Unloading {shift_model.model.model.__class__.__name__}")
+        if shift_model.model_unload(memory_to_free):
+            unloaded_model.append((i, shift_model))
 
-    for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+    for i, shift_model in sorted(unloaded_model, reverse=True):
+        unloaded_models.append(shift_model)
+        if shift_model in current_loaded_models:
+            current_loaded_models.remove(shift_model)
 
     if len(unloaded_model) > 0:
         soft_empty_cache()

From 99bed5e19fe831bb80b26420e96c646d26c0fc9b Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 18 Nov 2025 09:42:56 +1000
Subject: [PATCH 6/7] mm: make model offloading deferred with weakrefs

RAMPressure caching may need to purge the same model that you are currently
trying to offload for VRAM freeing. In this case, the RAMPressure cache takes
priority and needs to be able to pull the trigger on dumping the whole model
and freeing the ModelPatcher in question.

To do this, defer the actual transfer of model weights from GPU to RAM to
model_management state rather than doing it as part of ModelPatcher. This is
done as a list of weakrefs. If the RAM cache decides to free the model you
are currently unloading, then the ModelPatcher and its refs simply disappear
in the middle of the unloading process, and both RAM and VRAM are freed.

The unpatcher now queues the individual leaf modules to be offloaded
one-by-one so that RAM levels can be monitored.

Note that the partial unload that is potentially done as part of a load will
not be freeable this way; however, it shouldn't need to be, as that is the
currently active model, and the RAM cache cannot save you if you can't even
fit the one model you are currently trying to use.
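The weakref hand-off can be sketched in isolation as follows (a simplified
illustration, assuming a toy torch.nn.Sequential in place of the
ModelPatcher's leaf modules; offload_modules here only mirrors the helper
added by this patch):

# Deferred offloading via weakrefs: the unpatch step records weak references
# only, and the offload loop skips any module whose owner has already been
# garbage-collected (e.g. dropped by the RAM-pressure cache mid-unload).
import gc
import weakref

import torch

def offload_modules(module_refs, offload_device):
    for ref in module_refs:
        module = ref()            # None if the model vanished mid-unload
        if module is None:
            continue
        module.to(offload_device)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
refs = [weakref.ref(m) for m in model]   # defer the move, keep no strong refs

del model                     # simulate the RAM cache freeing the model first
gc.collect()
offload_modules(refs, torch.device("cpu"))   # dead refs are skipped safely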
--- comfy/model_management.py | 16 ++++++++++++++-- comfy/model_patcher.py | 24 ++++++++++++++++-------- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 18a700905..5a7f23e30 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -535,12 +535,17 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): + if self.model is None: + return True if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free) + offload_modules(modules_to_offload, self.model.offload_device) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) + if self.model is not None: + modules_to_offload = self.model.detach(unpatch_weights) + offload_modules(modules_to_offload, self.model.offload_device) self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -592,6 +597,13 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() +def offload_modules(modules, offload_device): + for module in modules: + if module() is None: + continue + module().to(offload_device) + free_ram() + def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() unloaded_model = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..078c23019 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import weakref from typing import Callable, Optional import torch @@ -830,6 +831,7 @@ class ModelPatcher: def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() + modules_to_move = [] if unpatch_weights: self.unpatch_hooks() self.unpin_all_weights() @@ -854,7 +856,8 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - self.model.to(device_to) + modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ] + modules_to_move.append(weakref.ref(self.model)) self.model.device = device_to self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -868,12 +871,14 @@ class ModelPatcher: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + return modules_to_move def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 + modules_to_move = [] unload_list = self._load_list() unload_list.sort() offload_buffer = self.model.model_offload_buffer_memory @@ -910,7 +915,7 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + modules_to_move.append(weakref.ref(m)) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: @@ -946,20 +951,22 @@ class ModelPatcher: self.model.model_loaded_weight_memory -= memory_freed self.model.model_offload_buffer_memory = offload_buffer logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) - return 
memory_freed + return memory_freed, modules_to_move def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) if unpatch_weights: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) if extra_memory < 0 and not unpatch_weights: - self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) return 0 full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: @@ -971,7 +978,7 @@ class ModelPatcher: try: self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) except Exception as e: - self.detach() + comfy.model_management.offload_modules(self.detach(), self.offload_device) raise e return self.model.model_loaded_weight_memory - current_used @@ -979,11 +986,12 @@ class ModelPatcher: def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) + modules_to_offload = [] if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): callback(self, unpatch_all) - return self.model + return modules_to_offload def current_loaded_device(self): return self.model.device From e65e642fbe9f19e10ee44f322b2f971c0b9b525f Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 27 Nov 2025 17:58:39 +1000 Subject: [PATCH 7/7] mm: fix debug message --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5a7f23e30..3a5a6c4e3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -537,6 +537,7 @@ class LoadedModel: def model_unload(self, memory_to_free=None, unpatch_weights=True): if self.model is None: return True + logging.debug(f"Unloading {self.model.model.__class__.__name__}") if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free) @@ -626,7 +627,6 @@ def free_memory(memory_required, device, keep_loaded=[]): if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.info(f"Unloading {shift_model.model.model.__class__.__name__}") if shift_model.model_unload(memory_to_free): unloaded_model.append((i, shift_model))
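Taken together, the series routes RAM freeing through model_management:
caches register themselves as listeners, and loaders call free_ram() with the
incoming state dict so that enough headroom is evicted before
load_state_dict runs. A minimal end-to-end sketch of that flow, with a
ToyCache standing in for RAMPressureCache (the listener set and the byte
accounting mirror the free_ram() helper from patch 2):

# free_ram() sizes the incoming weights and asks every registered listener
# to make that much extra headroom available before the load proceeds.
import torch

ram_listeners = set()

def register_ram_listener(listener):
    ram_listeners.add(listener)

def free_ram(extra_ram=0, state_dict={}):
    for tensor in state_dict.values():
        if isinstance(tensor, torch.Tensor):
            extra_ram += tensor.numel() * tensor.element_size()
    for listener in ram_listeners:
        listener.free_ram(extra_ram)

class ToyCache:
    # Stand-in for RAMPressureCache: a real cache would evict entries until
    # min_headroom plus extra_ram bytes of RAM are available again.
    def free_ram(self, extra_ram=0):
        print(f"evict until headroom + {extra_ram / (1024**3):.3f} GB is free")

register_ram_listener(ToyCache())
sd = {"weight": torch.zeros(1024, 1024)}   # ~4 MB of float32 weights
free_ram(state_dict=sd)                    # called just before load_state_dict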