diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 5b1ccb824..62c4f291f 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -35,6 +35,7 @@
 import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
@@ -662,12 +663,18 @@ class ModelPatcher:
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
-                    weight_key = "{}.weight".format(n)
-                    bias_key = "{}.bias".format(n)
-                    if weight_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                    if bias_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                    def check_module_offload_mem(key):
+                        if key in self.patches:
+                            return low_vram_patch_estimate_vram(self.model, key)
+                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if model_dtype is None or weight is None:
+                            return 0
+                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                            return weight.numel() * model_dtype.itemsize
+                        return 0
+                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
 
                 loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
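
The new check_module_offload_mem helper also accounts for keys with no patches: when the stored weight's dtype differs from the model's manual_cast_dtype, or the weight is a QuantizedTensor, the offload estimate additionally budgets for a cast copy of the weight in the compute dtype (numel times the compute dtype's itemsize). Below is a minimal standalone sketch of that per-key estimate, assuming plain torch tensors; cast_overhead_bytes and is_quantized are illustrative names, not identifiers from the repository (the patch itself reads the weight via get_key_weight and checks isinstance(weight, QuantizedTensor)).

# Illustrative sketch, not part of the patch: extra memory needed when a
# stored weight must be cast to the model's compute dtype before use.
from typing import Optional

import torch


def cast_overhead_bytes(weight: Optional[torch.Tensor],
                        compute_dtype: Optional[torch.dtype],
                        is_quantized: bool = False) -> int:
    """Extra bytes needed to hold `weight` cast to `compute_dtype`."""
    if weight is None or compute_dtype is None:
        return 0
    if weight.dtype != compute_dtype or is_quantized:
        # A cast copy of the weight is created at the compute dtype, so the
        # estimate is the element count times that dtype's per-element size.
        return weight.numel() * compute_dtype.itemsize
    return 0


if __name__ == "__main__":
    w = torch.zeros(1024, 1024, dtype=torch.float16)
    print(cast_overhead_bytes(w, torch.bfloat16))  # 2097152: needs a bf16 copy
    print(cast_overhead_bytes(w, torch.float16))   # 0: already in the compute dtype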