Update model_optimization_nodes.py

This commit is contained in:
kijai 2025-03-11 19:46:10 +02:00
parent 63966e3483
commit 51b9efe0a1

View File

@@ -218,6 +218,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
return (model,)
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
with self.use_ejected():
if lowvram_model_memory == 0:
full_load = True
@@ -225,14 +226,14 @@ def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
full_load = False
device_to = mm.get_torch_device()
load_weights = True
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
self.inject_model()
return self.model
def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -248,6 +249,7 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
#print(temp_object_patches_backup)