torch compile with LoRAs

This commit is contained in:
kijai 2024-11-04 18:40:51 +02:00
parent 3e11fff5f5
commit 8cfbbaf29e
2 changed files with 111 additions and 3 deletions

@ -156,6 +156,7 @@ NODE_CONFIG = {
"TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"}, "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
"TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"}, "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
"TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"}, "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
"PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},
#instance diffusion #instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

@ -3,7 +3,10 @@ import numpy as np
from PIL import Image
import json, re, os, io, time
import re
import importlib
from contextlib import contextmanager
import model_management
import folder_paths
from nodes import MAX_RESOLUTION
@ -2238,7 +2241,111 @@ class CheckpointLoaderKJ:
        return model, clip, vae
import re
import comfy.model_patcher
import comfy.lora
import comfy.utils
import comfy.sd

original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
original_unpatch_model = comfy.model_patcher.ModelPatcher.unpatch_model
original_load_lora_for_models = comfy.sd.load_lora_for_models

def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
    if lowvram_model_memory == 0:
        full_load = True
    else:
        full_load = False

    if load_weights:
        self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
    for k in self.object_patches:
        old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
        if k not in self.object_patches_backup:
            self.object_patches_backup[k] = old
    return self.model
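
The override above reorders ComfyUI's stock patch_model: weights (including LoRA weight patches) are loaded first, and object patches such as torch.compile-wrapped blocks are applied on top of them. A minimal standalone sketch of the class-level monkey-patching pattern used here (toy names, not ComfyUI's API):

class Toy:
    def patch(self):
        return "original"

_original_patch = Toy.patch    # keep a handle so the stock method can be restored

def _patched_patch(self):
    return "patched"           # swapped-in behavior

Toy.patch = _patched_patch     # class-level override: affects every instance
assert Toy().patch() == "patched"
Toy.patch = _original_patch    # restore the stock method
assert Toy().patch() == "original"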
def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    patch_keys = list(model.object_patches_backup.keys())
    for k in patch_keys:
        #print("backing up object patch: ", k)
        comfy.utils.set_attr(model.model, k, model.object_patches_backup[k])

    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    loaded = comfy.lora.load_lora(lora, key_map)
    #print(temp_object_patches_backup)
    if model is not None:
        new_modelpatcher = model.clone()
        k = new_modelpatcher.add_patches(loaded, strength_model)
    else:
        k = ()
        new_modelpatcher = None

    if clip is not None:
        new_clip = clip.clone()
        k1 = new_clip.add_patches(loaded, strength_clip)
    else:
        k1 = ()
        new_clip = None

    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            print("NOT LOADED {}".format(x))

    if patch_keys:
        for k in patch_keys:
            if "diffusion_model." in k:
                # Remove the prefix to get the attribute path
                key = k.replace('diffusion_model.', '')
                attributes = key.split('.')
                # Start with the diffusion_model object
                block = model.get_model_object("diffusion_model")
                # Navigate through the attributes to get to the block
                for attr in attributes:
                    if attr.isdigit():
                        block = block[int(attr)]
                    else:
                        block = getattr(block, attr)
                # Compile the block
                compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
                # Add the compiled block back as an object patch
                model.add_object_patch(k, compiled_block)
    return (new_modelpatcher, new_clip)
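
The block lookup above walks a dotted patch key such as "diffusion_model.double_blocks.3" down to the actual submodule, treating numeric parts as list indices. The same logic as a self-contained sketch (hypothetical helper name):

def resolve_attr_path(root, path):
    """Follow a path like 'a.b.0.c' through attributes, treating digit parts as indices."""
    obj = root
    for part in path.split('.'):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

# e.g. resolve_attr_path(diffusion_model, "double_blocks.3") would return the
# fourth double block, ready to be wrapped with torch.compile.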
class PatchModelPatcherOrder:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
        }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "Patches the order in which comfy's patch_model function applies patches; useful for torch.compile (used as an object patch), which should come last if you want to use LoRAs together with compile"
    EXPERIMENTAL = True

    def patch(self, model, patch_order):
        comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
        if patch_order == "weight_patch_first":
            comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
            comfy.sd.load_lora_for_models = patched_load_lora_for_models
        else:
            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
            comfy.model_patcher.ModelPatcher.unpatch_model = original_unpatch_model
            comfy.sd.load_lora_for_models = original_load_lora_for_models
        return model,
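
With the node in place, the intended flow appears to be: run PatchModelPatcherOrder with "weight_patch_first" before sampling, so that loading a LoRA strips any compiled object patches, applies the weight patches, and re-compiles the affected blocks. A rough sketch of the equivalent direct calls (the lora state dict and strengths are placeholders):

model, = PatchModelPatcherOrder().patch(model, "weight_patch_first")
# ... a compile node such as TorchCompileModelFluxAdvanced adds its object patches ...
model, clip = patched_load_lora_for_models(model, clip, lora, 1.0, 1.0)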
class TorchCompileModelFluxAdvanced:
    def __init__(self):
        self._compiled = False
@ -2280,11 +2387,11 @@ class TorchCompileModelFluxAdvanced:
        try:
            for i, block in enumerate(diffusion_model.double_blocks):
                if i in double_block_list:
                    #print("Compiling double_block", i)
                    m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
            for i, block in enumerate(diffusion_model.single_blocks):
                if i in single_block_list:
                    #print("Compiling single block", i)
                    m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
            self._compiled = True
        except: