Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI, synced 2025-12-14 16:34:36 +08:00
Account for dequantization and type-casts in offload costs
When measuring the offload cost, identify weights that need a type cast or dequantization and add the size of the conversion result to the offload cost. This is mutually exclusive with lowvram patches, which already have a large conservative estimate that won't overlap with the dequantization cost, so don't double count.
parent 0833e3b801
commit 53bd09926c
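For context, the extra term this change adds for an unpatched weight is simply the memory footprint of the cast result: the element count times the byte size of the compute dtype (the real helper also treats QuantizedTensor weights as needing conversion). A minimal standalone sketch of that estimate follows; the tensor shape, fp8 storage dtype, and bf16 compute dtype are illustrative assumptions, not values taken from the commit.

import torch

def cast_cost_bytes(weight: torch.Tensor, model_dtype: torch.dtype) -> int:
    # A weight already stored in the compute dtype needs no conversion buffer.
    if weight.dtype == model_dtype:
        return 0
    # Otherwise the cast (or dequantization) materializes a new tensor whose
    # size is the element count times the target dtype's element size.
    return weight.numel() * model_dtype.itemsize

# Example: a 4096x4096 weight stored in fp8 but computed in bf16.
w = torch.zeros(4096, 4096, dtype=torch.float8_e4m3fn)
print(cast_cost_bytes(w, torch.bfloat16))  # 4096 * 4096 * 2 = 33554432 bytes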
@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
 
@@ -662,12 +663,18 @@ class ModelPatcher:
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
-                    weight_key = "{}.weight".format(n)
-                    bias_key = "{}.bias".format(n)
-                    if weight_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                    if bias_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                    def check_module_offload_mem(key):
+                        if key in self.patches:
+                            return low_vram_patch_estimate_vram(self.model, key)
+                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if model_dtype is None or weight is None:
+                            return 0
+                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                            return weight.numel() * model_dtype.itemsize
+                        return 0
+                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
 
                 loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
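To make the commit message's "mutually exclusive" point concrete: the new helper returns the lowvram patch estimate as soon as a key is found in self.patches, so the cast/dequantization size is only ever charged for unpatched weights. Below is a small sketch of that decision order with hypothetical names and numbers; nothing in it is taken from the repository.

def offload_cost(key, patches, lowvram_estimate, cast_cost):
    # Patched keys already carry a conservative lowvram estimate that covers
    # any dequantization, so return it and add nothing else.
    if key in patches:
        return lowvram_estimate(key)
    # Unpatched keys only pay for the cast/dequantization result, if any.
    return cast_cost(key)

# Hypothetical weights: one lowvram-patched, one fp8 weight cast to bf16.
patches = {"blocks.0.attn.qkv.weight": object()}
estimate = lambda key: 64 * 1024 * 1024   # conservative patch estimate, bytes
cast = lambda key: 4096 * 4096 * 2        # numel * bf16 itemsize, bytes
print(offload_cost("blocks.0.attn.qkv.weight", patches, estimate, cast))  # 67108864
print(offload_cost("blocks.0.mlp.fc1.weight", patches, estimate, cast))   # 33554432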