diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: