Possible workaround for a memory issue in the model patch order patch

kijai 2025-03-12 00:05:42 +02:00
parent 51b9efe0a1
commit a4b9fd36da


@@ -220,12 +220,14 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
 def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
     with self.use_ejected():
-        if lowvram_model_memory == 0:
-            full_load = True
-        else:
-            full_load = False
         device_to = mm.get_torch_device()
+        full_load_override = getattr(self.model, "full_load_override", "auto")
+        if full_load_override in ["enabled", "disabled"]:
+            full_load = full_load_override == "enabled"
+        else:
+            full_load = lowvram_model_memory == 0
         self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
         for k in self.object_patches:
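
The hunk above is the core of the workaround: an explicit override stored on the model object now beats the old lowvram heuristic, which always forced a full load whenever lowvram_model_memory was 0. A minimal sketch of that resolution as a standalone function, assuming a bare object stands in for self.model (resolve_full_load and _Stub are hypothetical names for illustration, not part of the repo):

# Sketch of the new full_load resolution from the hunk above.
def resolve_full_load(model, lowvram_model_memory):
    # An explicit "enabled"/"disabled" override on the model wins;
    # "auto" keeps the old behavior (full load only when no
    # low-VRAM budget is requested).
    override = getattr(model, "full_load_override", "auto")
    if override in ["enabled", "disabled"]:
        return override == "enabled"
    return lowvram_model_memory == 0

class _Stub:  # hypothetical stand-in for self.model
    pass

m = _Stub()
assert resolve_full_load(m, 0) is True       # auto: no low-VRAM cap -> full load
assert resolve_full_load(m, 4096) is False   # auto: cap set -> partial load
m.full_load_override = "disabled"
assert resolve_full_load(m, 0) is False      # override beats the heuristic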
@@ -301,6 +303,7 @@ class PatchModelPatcherOrder:
         return {"required": {
             "model": ("MODEL",),
             "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
+            "full_load": (["enabled", "disabled", "auto"], {"default": "auto", "tooltip": "Disabling may help with memory issues when loading large models, when changing this you should probably force model reload to avoid issues!"}),
             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
@@ -308,9 +311,10 @@ class PatchModelPatcherOrder:
     DESCRIPTION = "Patch the comfy patch_model function patching order, useful for torch.compile (used as object_patch) as it should come last if you want to use LoRAs with compile"
     EXPERIMENTAL = True

-    def patch(self, model, patch_order):
+    def patch(self, model, patch_order, full_load):
         comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
-        if patch_order == "weight_patch_first":
+        setattr(model.model, "full_load_override", full_load)
+        if patch_order == "weight_patch_first":
             comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
             comfy.sd.load_lora_for_models = patched_load_lora_for_models
         else:
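
Taken together, the hunks form a handshake: the node monkey-patches patch_model at the class level and pins the user's choice onto model.model, where the replacement function finds it much later in the load path. A hedged sketch of that pattern with stand-in classes (FakeModel and FakePatcher are invented here; only the attribute name full_load_override mirrors the diff):

# Why setattr on model.model works as a message channel: the class-level
# method swap happens at node-execution time, and the replacement reads
# the attribute off the same object the node wrote it to.
class FakeModel:  # stand-in, not a ComfyUI class
    pass

class FakePatcher:  # stand-in, not comfy.model_patcher.ModelPatcher
    def __init__(self):
        self.model = FakeModel()

    def patch_model(self, lowvram_model_memory=0):
        return lowvram_model_memory == 0  # original behavior, no override

def patched(self, lowvram_model_memory=0):
    # Replacement honors the override pinned onto self.model.
    override = getattr(self.model, "full_load_override", "auto")
    if override in ["enabled", "disabled"]:
        return override == "enabled"
    return lowvram_model_memory == 0

FakePatcher.patch_model = patched        # class-level swap, as in the node
patcher = FakePatcher()
setattr(patcher.model, "full_load_override", "disabled")
assert patcher.patch_model(0) is False   # override wins over the heuristic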