mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 22:14:34 +08:00
Speed up offloading using pinned memory. (#10526)
To enable this feature use: --fast pinned_memory
This commit is contained in:
parent
210f7a1ba5
commit
3fa7a5c04a
@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
PinnedMem = "pinned_memory"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
|
||||
@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def pin_memory(tensor):
|
||||
if PerformanceFeature.PinnedMem not in args.fast:
|
||||
return False
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
if PerformanceFeature.PinnedMem not in args.fast:
|
||||
return False
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
|
||||
@ -238,6 +238,7 @@ class ModelPatcher:
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
self.pinned = set()
|
||||
|
||||
self.attachments: dict[str] = {}
|
||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||
@ -618,6 +619,21 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if comfy.model_management.pin_memory(weight):
|
||||
self.pinned.add(key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@ -683,6 +699,8 @@ class ModelPatcher:
|
||||
patch_counter += 1
|
||||
|
||||
cast_weight = True
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
@ -713,7 +731,9 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@ -762,6 +782,7 @@ class ModelPatcher:
|
||||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
@ -857,6 +878,9 @@ class ModelPatcher:
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user