From 8cfbbaf29e0ee44ab8ec0155d65e6f510549c599 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 4 Nov 2024 18:40:51 +0200
Subject: [PATCH] torch compile with LoRAs

---
 __init__.py    |   1 +
 nodes/nodes.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 114 insertions(+), 3 deletions(-)

diff --git a/__init__.py b/__init__.py
index 75e6c55..02bd738 100644
--- a/__init__.py
+++ b/__init__.py
@@ -156,6 +156,7 @@ NODE_CONFIG = {
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
+    "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},

     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/nodes.py b/nodes/nodes.py
index a1e0cd6..a00a94c 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -3,7 +3,10 @@ import numpy as np
 from PIL import Image
 import json, re, os, io, time
+import re
 import importlib
+from contextlib import contextmanager
+
 import model_management
 import folder_paths
 from nodes import MAX_RESOLUTION
@@ -2238,7 +2241,114 @@ class CheckpointLoaderKJ:
         return model, clip, vae

-import re
+
+import comfy.model_patcher
+import comfy.utils
+import comfy.sd
+original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
+original_unpatch_model = comfy.model_patcher.ModelPatcher.unpatch_model
+original_load_lora_for_models = comfy.sd.load_lora_for_models
+
+# Replacement for ModelPatcher.patch_model: load weight patches (LoRAs) first, then apply object patches on top
+def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+
+    if lowvram_model_memory == 0:
+        full_load = True
+    else:
+        full_load = False
+
+    if load_weights:
+        self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+    for k in self.object_patches:
+        old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+        if k not in self.object_patches_backup:
+            self.object_patches_backup[k] = old
+
+    return self.model
+
+# Replacement for comfy.sd.load_lora_for_models: restore the uncompiled modules, apply LoRA patches, then re-compile
+def patched_load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+
+    patch_keys = list(model.object_patches_backup.keys())
+    for k in patch_keys:
+        #print("backing up object patch: ", k)
+        comfy.utils.set_attr(model.model, k, model.object_patches_backup[k])
+
+    key_map = {}
+    if model is not None:
+        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
+    if clip is not None:
+        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+
+    loaded = comfy.lora.load_lora(lora, key_map)
+    #print(temp_object_patches_backup)
+
+    if model is not None:
+        new_modelpatcher = model.clone()
+        k = new_modelpatcher.add_patches(loaded, strength_model)
+    else:
+        k = ()
+        new_modelpatcher = None
+
+    if clip is not None:
+        new_clip = clip.clone()
+        k1 = new_clip.add_patches(loaded, strength_clip)
+    else:
+        k1 = ()
+        new_clip = None
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("NOT LOADED {}".format(x))
+
+    if patch_keys:
+        for k in patch_keys:
+            if "diffusion_model." in k:
+                # Remove the prefix to get the attribute path
+                key = k.replace('diffusion_model.', '')
+                attributes = key.split('.')
+                # Start with the diffusion_model object
+                block = model.get_model_object("diffusion_model")
+                # Navigate through the attributes to get to the block
+                for attr in attributes:
+                    if attr.isdigit():
+                        block = block[int(attr)]
+                    else:
+                        block = getattr(block, attr)
+                # Compile the block
+                compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
+                # Add the compiled block back as an object patch
+                model.add_object_patch(k, compiled_block)
+    return (new_modelpatcher, new_clip)
+
+
+class PatchModelPatcherOrder:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "patch_order": (["object_patch_first", "weight_patch_first"], {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"}),
+        }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "KJNodes/experimental"
+    DESCRIPTION = "Patches the order in which patch_model applies patches: torch.compile (used as an object patch) should come last if you want to use LoRAs together with compile"
+    EXPERIMENTAL = True
+
+    def patch(self, model, patch_order):
+        comfy.model_patcher.ModelPatcher.temp_object_patches_backup = {}
+        if patch_order == "weight_patch_first":
+            comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
+            comfy.sd.load_lora_for_models = patched_load_lora_for_models
+        else:
+            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
+            comfy.model_patcher.ModelPatcher.unpatch_model = original_unpatch_model
+            comfy.sd.load_lora_for_models = original_load_lora_for_models
+
+        return model,
+

 class TorchCompileModelFluxAdvanced:
     def __init__(self):
         self._compiled = False
@@ -2280,11 +2390,11 @@ class TorchCompileModelFluxAdvanced:
         try:
             for i, block in enumerate(diffusion_model.double_blocks):
                 if i in double_block_list:
-                    print("Compiling double_block", i)
+                    #print("Compiling double_block", i)
                     m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
             for i, block in enumerate(diffusion_model.single_blocks):
                 if i in single_block_list:
-                    print("Compiling single block", i)
+                    #print("Compiling single block", i)
                     m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
             self._compiled = True
         except:
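
Reviewer note: the node works by swapping methods on the ModelPatcher class itself. A minimal, self-contained sketch of that save/swap/restore pattern follows (illustrative only, not part of the commit; the ModelPatcher stub and set_patch_order helper are hypothetical stand-ins for the comfy classes):

class ModelPatcher:  # hypothetical stand-in for comfy.model_patcher.ModelPatcher
    def patch_model(self):
        return "original behavior"

# Capture the original once, before any swap (mirrors original_patch_model above).
original_patch_model = ModelPatcher.patch_model

def patched_patch_model(self):
    return "patched behavior"

def set_patch_order(weight_patch_first: bool):
    # Assigning to the class attribute redirects the method for every instance
    # at once, which is why assigning the saved original back undoes the patch.
    if weight_patch_first:
        ModelPatcher.patch_model = patched_patch_model
    else:
        ModelPatcher.patch_model = original_patch_model

set_patch_order(True)
assert ModelPatcher().patch_model() == "patched behavior"
set_patch_order(False)
assert ModelPatcher().patch_model() == "original behavior"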
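Reviewer note: patched_load_lora_for_models resolves each backed-up object-patch key by walking the dotted attribute path by hand. A minimal runnable sketch of that navigation plus the re-compile step (ToyDiffusionModel and resolve_block are hypothetical names used only for illustration):

import torch
import torch.nn as nn

class ToyDiffusionModel(nn.Module):  # hypothetical stand-in for the real diffusion model
    def __init__(self):
        super().__init__()
        self.double_blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

def resolve_block(root: nn.Module, key: str) -> nn.Module:
    # Walk a dotted key such as "double_blocks.0": numeric parts index into
    # ModuleLists, everything else is a plain attribute lookup.
    block = root
    for attr in key.split('.'):
        block = block[int(attr)] if attr.isdigit() else getattr(block, attr)
    return block

model = ToyDiffusionModel()
block = resolve_block(model, "double_blocks.0")
compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
# In the patch, the compiled module is then re-registered with
# model.add_object_patch("diffusion_model.double_blocks.0", compiled_block),
# so compilation happens after the LoRA weight patches are in place.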