From 519c9411653df99761053c30e101816e0ca3c24b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:28:45 +1000 Subject: [PATCH] Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069) * mp: only count the offload cost of math once This was previously bundling the combined weight storage and computation cost * ops: put all post async transfer compute on the main stream Some models have massive weights that need either complex dequantization or lora patching. Don't do these patchings on the offload stream, instead do them on the main stream to syncrhonize the potentially large vram spikes for these compute processes. This avoids having to assume a worst case scenario of multiple offload streams all spiking VRAM is parallel with whatever the main stream is doing. --- comfy/model_patcher.py | 4 ++-- comfy/ops.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: