Speed up offloading using pinned memory. (#10526)

To enable this feature, use: --fast pinned_memory
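For reference, a typical launch (assuming the standard ComfyUI main.py entrypoint; other --fast options can be listed alongside it):

    python main.py --fast pinned_memory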
Author: comfyanonymous (committed by GitHub)
Date: 2025-10-28 21:21:01 -07:00
Commit: 3fa7a5c04a (parent: 210f7a1ba5)
3 changed files with 56 additions and 1 deletion

comfy/cli_args.py

@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
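For context (not part of the diff), a self-contained sketch of how that flag parses, assuming only the argparse behavior configured above:

    import argparse
    import enum

    class PerformanceFeature(enum.Enum):
        Fp8MatrixMultiplication = "fp8_matrix_mult"
        CublasOps = "cublas_ops"
        AutoTune = "autotune"
        PinnedMem = "pinned_memory"

    parser = argparse.ArgumentParser()
    parser.add_argument("--fast", nargs="*", type=PerformanceFeature)

    args = parser.parse_args(["--fast", "pinned_memory"])
    # argparse calls PerformanceFeature("pinned_memory") for each value,
    # so membership checks like the ones later in this diff work on enum members:
    print(PerformanceFeature.PinnedMem in args.fast)  # True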

comfy/model_management.py

@@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)

+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
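Background on why this helps (not part of this commit): cudaHostRegister() page-locks the tensor's existing host allocation (the flag value 1 is cudaHostRegisterPortable, and a return of 0 is cudaSuccess), which lets later non-blocking host-to-device copies run as true asynchronous DMA transfers instead of going through a pageable staging buffer. A minimal standalone sketch of the effect, using PyTorch's built-in pin_memory() rather than the cudaHostRegister path above:

    import time
    import torch

    def average_h2d_seconds(host_tensor, iters=20):
        # Time repeated host -> GPU copies of the same tensor.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            host_tensor.to("cuda", non_blocking=True)
            torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    if torch.cuda.is_available():
        pageable = torch.randn(64, 1024, 1024)             # ordinary pageable host memory
        pinned = torch.randn(64, 1024, 1024).pin_memory()  # page-locked host memory
        print("pageable:", average_h2d_seconds(pageable))
        print("pinned:  ", average_h2d_seconds(pinned))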

comfy/model_patcher.py

@@ -238,6 +238,7 @@ class ModelPatcher:
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()

         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -618,6 +619,21 @@
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
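A reduced, self-contained model of the bookkeeping above (hypothetical class; pin_fn/unpin_fn stand in for comfy.model_management.pin_memory/unpin_memory): only keys whose host buffer was actually page-locked get tracked, so unpinning is always safe to call.

    class PinTracker:
        # Hypothetical reduction of the ModelPatcher logic above.
        def __init__(self, pin_fn, unpin_fn):
            self.pinned = set()
            self.pin_fn = pin_fn
            self.unpin_fn = unpin_fn

        def pin(self, key, tensor):
            if self.pin_fn(tensor):        # False when the feature is off, non-NVIDIA, or not a CPU tensor
                self.pinned.add(key)

        def unpin(self, key, tensor):
            if key in self.pinned:         # no-op for keys that never got pinned
                self.unpin_fn(tensor)
                self.pinned.remove(key)

        def unpin_all(self, tensors):
            for key in list(self.pinned):  # iterate over a copy: unpin() mutates the set
                self.unpin(key, tensors[key])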
@@ -683,6 +699,8 @@
                             patch_counter += 1

                 cast_weight = True
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -713,7 +731,9 @@
                         continue

                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)

                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
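The ordering in this hunk appears deliberate: cudaHostRegister() pins the specific host allocation backing the weight, and patch_weight_to_device() then replaces that tensor with one on the target device, so the registration is released first rather than left attached to a buffer the model no longer references. A hypothetical reduction of that unpin-then-move order (tracker as in the PinTracker sketch above; names are illustrative):

    import torch

    def move_param_to_device(module, name, device, tracker):
        old = getattr(module, name)
        tracker.unpin(name, old)        # drop the page-lock before the CPU copy goes away
        moved = old.detach().to(device, copy=True)
        setattr(module, name, torch.nn.Parameter(moved, requires_grad=False))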
@@ -762,6 +782,7 @@
        self.eject_model()
        if unpatch_weights:
            self.unpatch_hooks()
+            self.unpin_all_weights()
            if self.model.model_lowvram:
                for m in self.model.modules():
                    move_weight_functions(m, device_to)
@@ -857,6 +878,9 @@
                    memory_freed += module_mem
                    logging.debug("freed {}".format(n))

+                    for param in params:
+                        self.pin_weight_to_device("{}.{}".format(n, param))
+
            self.model.model_lowvram = True
            self.model.lowvram_patch_counter += patch_counter
            self.model.model_loaded_weight_memory -= memory_freed
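Taken together, the lifecycle this commit introduces can be summarized as follows (a hedged reading of the hunks above, not code from the commit):

    # load() in lowvram mode  -> weights that stay on the CPU get pin_weight_to_device()
    # load() of full modules  -> unpin_weight(key) first, then patch_weight_to_device(key, device_to)
    # partially_unload()      -> weights offloaded back to the CPU get pin_weight_to_device()
    # unpatch_model()         -> unpin_all_weights() releases every registration
    #
    # The payoff lands in cast_to_device()/cast_to(): with the source weight page-locked,
    # non_blocking=True host-to-device copies on NVIDIA GPUs can run asynchronously,
    # which is what speeds up offloading.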