mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00

torch compile with LoRAs

parent 3e11fff5f5
commit 8cfbbaf29e
@@ -156,6 +156,7 @@ NODE_CONFIG = {
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
+    "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},

     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
113 nodes/nodes.py
@@ -3,7 +3,10 @@ import numpy as np
 from PIL import Image

 import json, re, os, io, time
 import re
+import importlib
+from contextlib import contextmanager
+
 import model_management
 import folder_paths
 from nodes import MAX_RESOLUTION
@@ -2238,7 +2241,111 @@ class CheckpointLoaderKJ:


         return model, clip, vae

+import re
+
+import comfy.model_patcher
+import comfy.utils
+import comfy.sd
+import comfy.lora  # used below via comfy.lora.*; made explicit here
+
+# Keep references to the stock implementations so they can be restored.
+original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
+original_unpatch_model = comfy.model_patcher.ModelPatcher.unpatch_model
+original_load_lora_for_models = comfy.sd.load_lora_for_models
+
+def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+
+    if lowvram_model_memory == 0:
+        full_load = True
+    else:
+        full_load = False
+
+    if load_weights:
+        self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+    for k in self.object_patches:
+        old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+        if k not in self.object_patches_backup:
+            self.object_patches_backup[k] = old
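+
+    # self.object_patches now holds the torch.compile wrappers, while
+    # object_patches_backup keeps the uncompiled originals; the patched
+    # LoRA loader below relies on that backup to restore the originals
+    # before merging weights.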
+
+    return self.model
+
+def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+
+    # Temporarily put the original (uncompiled) modules back so the LoRA
+    # weight patches are keyed against the real model structure.
+    patch_keys = list(model.object_patches_backup.keys())
+    for k in patch_keys:
+        #print("backing up object patch: ", k)
+        comfy.utils.set_attr(model.model, k, model.object_patches_backup[k])
+
+    key_map = {}
+    if model is not None:
+        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
+    if clip is not None:
+        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+
+    loaded = comfy.lora.load_lora(lora, key_map)
+    #print(temp_object_patches_backup)
+
+    if model is not None:
+        new_modelpatcher = model.clone()
+        k = new_modelpatcher.add_patches(loaded, strength_model)
+    else:
+        k = ()
+        new_modelpatcher = None
+
+    if clip is not None:
+        new_clip = clip.clone()
+        k1 = new_clip.add_patches(loaded, strength_clip)
+    else:
+        k1 = ()
+        new_clip = None
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("NOT LOADED {}".format(x))
+
+    if patch_keys:
+        for k in patch_keys:
+            if "diffusion_model." in k:
+                # Remove the prefix to get the attribute path
+                key = k.replace('diffusion_model.', '')
+                attributes = key.split('.')
+                # Start with the diffusion_model object
+                block = model.get_model_object("diffusion_model")
+                # Navigate through the attributes to get to the block
+                for attr in attributes:
+                    if attr.isdigit():
+                        block = block[int(attr)]
+                    else:
+                        block = getattr(block, attr)
+                # Compile the block
+                compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
+                # Add the compiled block back as an object patch
+                model.add_object_patch(k, compiled_block)
+    return (new_modelpatcher, new_clip)
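+
+# Net effect of patched_load_lora_for_models: LoRA weights are merged into
+# the uncompiled modules first, then every previously patched
+# diffusion_model block is re-compiled and re-registered as an object
+# patch, so compilation always runs after the weights have changed.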
+
+
+class PatchModelPatcherOrder:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
+            }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "KJNodes/experimental"
+    DESCRIPTION = "Patches the order of the comfy patch_model function, useful for torch.compile (used as an object_patch) since compiling should come last if you want to use LoRAs with compile"
+    EXPERIMENTAL = True
+
+    def patch(self, model, patch_order):
+        comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
+        if patch_order == "weight_patch_first":
+            comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
+            comfy.sd.load_lora_for_models = patched_load_lora_for_models
+        else:
+            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
+            comfy.model_patcher.ModelPatcher.unpatch_model = original_unpatch_model
+            comfy.sd.load_lora_for_models = original_load_lora_for_models
+
+        return model,

 class TorchCompileModelFluxAdvanced:
     def __init__(self):
         self._compiled = False
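Both the LoRA re-compile loop above and the per-block patches in the hunk below rely on the same pattern: resolve a dotted path such as double_blocks.0 inside a torch module, compile that sub-module, and register the compiled version. A standalone sketch of the pattern, where Net and resolve_block are toy stand-ins rather than ComfyUI code:

import torch

class Net(torch.nn.Module):
    # Toy stand-in for the diffusion model; only the container layout matters.
    def __init__(self):
        super().__init__()
        self.double_blocks = torch.nn.ModuleList(
            [torch.nn.Linear(4, 4) for _ in range(2)]
        )

def resolve_block(root, dotted_path):
    # Numeric path components index into ModuleList/Sequential containers,
    # everything else is a plain attribute lookup.
    block = root
    for attr in dotted_path.split('.'):
        block = block[int(attr)] if attr.isdigit() else getattr(block, attr)
    return block

net = Net()
block = resolve_block(net, "double_blocks.0")
compiled = torch.compile(block, mode="default", fullgraph=False, backend="inductor")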
@@ -2280,11 +2387,11 @@ class TorchCompileModelFluxAdvanced:
         try:
             for i, block in enumerate(diffusion_model.double_blocks):
                 if i in double_block_list:
-                    print("Compiling double_block", i)
+                    #print("Compiling double_block", i)
                     m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
             for i, block in enumerate(diffusion_model.single_blocks):
                 if i in single_block_list:
-                    print("Compiling single block", i)
+                    #print("Compiling single block", i)
                     m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
             self._compiled = True
         except:
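Stepping back, the whole mechanism is ordinary attribute monkey-patching: capture the stock functions once at import time, swap in replacements when the node runs with weight_patch_first, and assign the saved references back for object_patch_first. A condensed sketch of that save/patch/restore pattern; apply_patches and restore_patches are illustrative names, not KJNodes functions:

import comfy.model_patcher
import comfy.sd

# Captured once, before anything is replaced.
ORIGINALS = {
    "patch_model": comfy.model_patcher.ModelPatcher.patch_model,
    "load_lora_for_models": comfy.sd.load_lora_for_models,
}

def apply_patches(new_patch_model, new_load_lora_for_models):
    comfy.model_patcher.ModelPatcher.patch_model = new_patch_model
    comfy.sd.load_lora_for_models = new_load_lora_for_models

def restore_patches():
    comfy.model_patcher.ModelPatcher.patch_model = ORIGINALS["patch_model"]
    comfy.sd.load_lora_for_models = ORIGINALS["load_lora_for_models"]

Because the swap mutates module and class attributes, it is process-global: it affects every subsequent workflow in the ComfyUI session until the originals are restored, which is what selecting object_patch_first in the node does.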