mp: use look-ahead actuals for stream offload VRAM calculation (#11096)

TIL that the WAN TE has a 2GB weight followed by 16MB as the next size
down. This means that team 8GB VRAM would fully offload the TE in async
offload mode as it just multiplied this giant size my the num streams.

Do the more complex logic of summing up the upcoming to-load weight
sizes to avoid triple counting this massive weight.

partial unload does the converse of recording the NS most recent
unloads as they go.
This commit is contained in:
rattus 2025-12-04 14:28:44 +10:00 committed by GitHub
parent ea17add3c6
commit 6be85c7920
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -699,12 +699,12 @@ class ModelPatcher:
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for x in loading:
for i, x in enumerate(loading):
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@ -876,14 +876,18 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
if len(unload_list) > 0:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -935,6 +939,8 @@ class ModelPatcher:
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params: