Speed up offloading using pinned memory. (#10526)

To enable this feature, use: --fast pinned_memory

parent 210f7a1ba5
commit 3fa7a5c04a
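
For background on why this helps: page-locked (pinned) host memory lets CUDA use asynchronous DMA for host/device transfers, so offloaded weights can be copied to and from the GPU faster and without stalling compute. A minimal standalone sketch of the effect, not part of this commit (sizes and iteration count are arbitrary; requires a CUDA build of PyTorch):

import time
import torch

def time_transfer(host_tensor, iters=20):
    # Average the host-to-device copy time over several iterations.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        host_tensor.to("cuda", non_blocking=True)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

pageable = torch.randn(64, 1024, 1024)             # ordinary pageable memory
pinned = torch.randn(64, 1024, 1024).pin_memory()  # page-locked copy

print("pageable: {:.4f}s".format(time_transfer(pageable)))
print("pinned:   {:.4f}s".format(time_transfer(pinned)))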
comfy/cli_args.py
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"
 
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
comfy/model_management.py
@@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 
 def sage_attention_enabled():
     return args.use_sage_attention
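A note on the helpers above: cudaHostRegister page-locks an existing host allocation in place rather than allocating new pinned memory, the flag value 1 is cudaHostRegisterPortable (the pinned range is visible to all CUDA contexts), and both cudart calls return 0 (cudaSuccess) on success, hence the comparisons against 0. A minimal standalone sketch of the same pattern outside ComfyUI (tensor size is arbitrary; requires a CUDA build of PyTorch):

import torch

# Page-lock an existing CPU tensor's storage, then release it.
# 0 is cudaSuccess; flag 1 is cudaHostRegisterPortable.
t = torch.randn(1024, 1024)  # ordinary pageable CPU tensor

ret = torch.cuda.cudart().cudaHostRegister(
    t.data_ptr(),                  # start of the host allocation
    t.numel() * t.element_size(),  # size in bytes
    1,                             # cudaHostRegisterPortable
)
if ret == 0:
    # Non-blocking host-to-device copies can now overlap with compute.
    gpu = t.to("cuda", non_blocking=True)
    torch.cuda.synchronize()
    torch.cuda.cudart().cudaHostUnregister(t.data_ptr())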
comfy/model_patcher.py
@@ -238,6 +238,7 @@ class ModelPatcher:
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()
 
         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -618,6 +619,21 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -683,6 +699,8 @@ class ModelPatcher:
                         patch_counter += 1
 
                 cast_weight = True
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -713,7 +731,9 @@ class ModelPatcher:
                     continue
 
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)
 
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
@@ -762,6 +782,7 @@ class ModelPatcher:
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -857,6 +878,9 @@ class ModelPatcher:
                     memory_freed += module_mem
                     logging.debug("freed {}".format(n))
 
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
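
Taken together, the ModelPatcher changes give pinning a clear lifecycle: weights that stay on (or are offloaded back to) the CPU in lowvram mode get pinned, a weight is unpinned just before it is patched onto the compute device, and everything is unpinned when the model is fully unpatched. A toy sketch of that lifecycle; all names below are hypothetical illustrations, not ComfyUI API:

import torch

def try_pin(t):
    # Hypothetical stand-in for comfy.model_management.pin_memory:
    # page-lock an existing CPU tensor in place; 0 is cudaSuccess.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.cudart().cudaHostRegister(
        t.data_ptr(), t.numel() * t.element_size(), 1) == 0

def try_unpin(t):
    # Hypothetical stand-in for comfy.model_management.unpin_memory.
    return torch.cuda.cudart().cudaHostUnregister(t.data_ptr()) == 0

class OffloadTracker:
    # Hypothetical illustration of the bookkeeping ModelPatcher.pinned does.
    def __init__(self, weights):
        self.weights = weights  # dict: key -> CPU tensor
        self.pinned = set()

    def offload(self, key):
        # Weight stays on the CPU: pin it so later copies are fast.
        if try_pin(self.weights[key]):
            self.pinned.add(key)

    def load(self, key, device):
        # Unpin just before the weight moves to the compute device.
        if key in self.pinned:
            try_unpin(self.weights[key])
            self.pinned.discard(key)
        return self.weights[key].to(device)

    def teardown(self):
        # Full unpatch: release every outstanding registration.
        for key in list(self.pinned):
            try_unpin(self.weights[key])
            self.pinned.discard(key)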