diff --git a/cogvideox_fun/lora_utils.py b/cogvideox_fun/lora_utils.py
new file mode 100644
index 0000000..37b51fc
--- /dev/null
+++ b/cogvideox_fun/lora_utils.py
@@ -0,0 +1,477 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+# https://github.com/bmaltais/kohya_ss
+
+import hashlib
+import math
+import os
+from collections import defaultdict
+from io import BytesIO
+from typing import List, Optional, Type, Union
+
+import safetensors.torch
+import torch
+import torch.utils.checkpoint
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from safetensors.torch import load_file
+from transformers import T5EncoderModel
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    replaces forward method of the original Linear, instead of replacing the original Linear module.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        self.lora_dim = lora_dim
+        if org_module.__class__.__name__ == "Conv2d":
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x, *args, **kwargs):
+        weight_dtype = x.dtype
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x.to(self.lora_down.weight.dtype))
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
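+
+        # The adapted output below is org_forward(x) + lora_up(lora_down(x)) * multiplier * scale,
+        # cast back to the input dtype so mixed-precision activations round-trip cleanly.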
+        lx = self.lora_up(lx)
+
+        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+    bytes = safetensors.torch.save(tensors, metadata)
+    b = BytesIO(bytes)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    return model_hash, legacy_hash
+
+
+class LoRANetwork(torch.nn.Module):
+    TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
+    LORA_PREFIX_TRANSFORMER = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    def __init__(
+        self,
+        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
+        unet,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        module_class: Type[object] = LoRAModule,
+        add_lora_in_attn_temporal: bool = False,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.dropout = dropout
+
+        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+        print(f"neuron dropout: p={self.dropout}")
+
+        # create module instances
+        def create_modules(
+            is_unet: bool,
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_TRANSFORMER
+                if is_unet
+                else self.LORA_PREFIX_TEXT_ENCODER
+            )
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if not add_lora_in_attn_temporal:
+                            if "attn_temporal" in child_name:
+                                continue
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
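+                            # Dots are replaced below so the key is a valid module /
+                            # state-dict name, e.g. "lora_unet_<block>_<layer>".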
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+
+                            if is_linear or is_conv2d_1x1:
+                                dim = self.lora_dim
+                                alpha = self.alpha
+
+                            if dim is None or dim == 0:
+                                if is_linear or is_conv2d_1x1:
+                                    skipped.append(lora_name)
+                                continue
+
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                            )
+                            loras.append(lora)
+            return loras, skipped
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if text_encoder is not None:
+                text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+                self.text_encoder_loras.extend(text_encoder_loras)
+                skipped_te += skipped
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(loras):
+            params = []
+            for lora in loras:
+                params.extend(lora.parameters())
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
+
+        if self.unet_loras:
+            param_data = {"params": enumerate_params(self.unet_loras)}
+            if unet_lr is not None:
+                param_data["lr"] = unet_lr
+            all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        pass
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
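+        # Optionally cast every tensor to the requested dtype on CPU before writing,
+        # so checkpoints can be saved in e.g. fp16/bf16 regardless of the training dtype.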
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
+    transformer,
+    neuron_dropout: Optional[float] = None,
+    add_lora_in_attn_temporal: bool = False,
+    **kwargs,
+):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    network = LoRANetwork(
+        text_encoder,
+        transformer,
+        multiplier=multiplier,
+        lora_dim=network_dim,
+        alpha=network_alpha,
+        dropout=neuron_dropout,
+        add_lora_in_attn_temporal=add_lora_in_attn_temporal,
+        varbose=True,
+    )
+    return network
+
+def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+    LORA_PREFIX_TRANSFORMER = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    if state_dict is None:
+        state_dict = load_file(lora_path, device=device)
+    else:
+        state_dict = state_dict
+    updates = defaultdict(dict)
+    for key, value in state_dict.items():
+        layer, elem = key.split('.', 1)
+        updates[layer][elem] = value
+
+    for layer, elems in updates.items():
+
+        if "lora_te" in layer:
+            if transformer_only:
+                continue
+            else:
+                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+                curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+            curr_layer = pipeline.transformer
+
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(layer_infos) == 0:
+                    print('Error loading layer')
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        weight_up = elems['lora_up.weight'].to(dtype)
+        weight_down = elems['lora_down.weight'].to(dtype)
+        if 'alpha' in elems.keys():
+            alpha = elems['alpha'].item() / weight_up.shape[1]
+        else:
+            alpha = 1.0
+
+        curr_layer.weight.data = curr_layer.weight.data.to(device)
+        if len(weight_up.shape) == 4:
+            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+                                                                    weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+
+    return pipeline
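+
+# Usage sketch (the file name here is hypothetical): merge before inference, then pass the
+# same path and multiplier to unmerge_lora to restore the original weights, e.g.
+#     pipe = merge_lora(pipe, "loras/example_lora.safetensors", 0.8)
+#     ... run the pipeline ...
+#     pipe = unmerge_lora(pipe, "loras/example_lora.safetensors", multiplier=0.8)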
+
+# TODO: Refactor with merge_lora.
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
+    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    state_dict = load_file(lora_path, device=device)
+
+    updates = defaultdict(dict)
+    for key, value in state_dict.items():
+        layer, elem = key.split('.', 1)
+        updates[layer][elem] = value
+
+    for layer, elems in updates.items():
+
+        if "lora_te" in layer:
+            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+            curr_layer = pipeline.transformer
+
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(layer_infos) == 0:
+                    print('Error loading layer')
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        weight_up = elems['lora_up.weight'].to(dtype)
+        weight_down = elems['lora_down.weight'].to(dtype)
+        if 'alpha' in elems.keys():
+            alpha = elems['alpha'].item() / weight_up.shape[1]
+        else:
+            alpha = 1.0
+
+        curr_layer.weight.data = curr_layer.weight.data.to(device)
+        if len(weight_up.shape) == 4:
+            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+                                                                    weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+
+    return pipeline
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index 692c635..560aee4 100644
--- a/nodes.py
+++ b/nodes.py
@@ -47,6 +47,7 @@ scheduler_mapping = {
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from .pipeline_cogvideox import CogVideoXPipeline
 from contextlib import nullcontext
+from pathlib import Path
 
 from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
 from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
@@ -54,6 +55,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
 from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
 from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
 from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
+from .cogvideox_fun.lora_utils import merge_lora, unmerge_lora
 from PIL import Image
 import numpy as np
 import json
@@ -204,6 +206,34 @@ class CogVideoTransformerEdit:
         blocks_to_remove = [int(x.strip()) for x in remove_blocks.split(',')]
         log.info(f"Blocks selected for removal: {blocks_to_remove}")
         return (blocks_to_remove,)
+
+
+folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
+
+class CogVideoLoraSelect:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "lora": (folder_paths.get_filename_list("cogvideox_loras"),
+                {"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
+                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
+            },
+        }
+
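+    # The COGLORA dict produced here is consumed by DownloadAndLoadCogVideoModel, which
+    # merges the LoRA into the pipeline at load time; strength 0.0 triggers unmerge_lora.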
RETURN_NAMES = ("lora", ) + FUNCTION = "getlorapath" + CATEGORY = "CogVideoWrapper" + + def getlorapath(self, lora, strength): + + cog_lora = { + "path": folder_paths.get_full_path("cogvideox_loras", lora), + "strength": strength + } + + return (cog_lora,) class DownloadAndLoadCogVideoModel: @classmethod @@ -235,6 +265,7 @@ class DownloadAndLoadCogVideoModel: "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), "pab_config": ("PAB_CONFIG", {"default": None}), "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), + "lora": ("COGLORA", {"default": None}), } } @@ -243,7 +274,7 @@ class DownloadAndLoadCogVideoModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None): + def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None): check_diffusers_version() @@ -344,6 +375,14 @@ class DownloadAndLoadCogVideoModel: vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) + if lora is not None: + if lora['strength'] > 0: + logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}") + pipe = merge_lora(pipe, lora["path"], lora["strength"]) + else: + logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}") + pipe = unmerge_lora(pipe, lora["path"], lora["strength"]) + if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() @@ -1190,6 +1229,7 @@ class CogVideoControlImageEncode: closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height, width = [int(x / 16) * 16 for x in closest_size] + log.info(f"Closest bucket size: {width}x{height}") video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1 input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width)) @@ -1294,9 +1334,6 @@ class CogVideoXFunControlSampler: autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: - # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): - # pipeline = merge_lora(pipeline, _lora_path, _lora_weight) - common_params = { "prompt_embeds": positive.to(dtype).to(device), "negative_prompt_embeds": negative.to(dtype).to(device), @@ -1320,8 +1357,6 @@ class CogVideoXFunControlSampler: scheduler_name=scheduler ) - # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): - # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight) return (pipeline, {"samples": latents}) NODE_CLASS_MAPPINGS = { @@ -1338,7 +1373,8 @@ NODE_CLASS_MAPPINGS = { "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel, "CogVideoPABConfig": CogVideoPABConfig, "CogVideoTransformerEdit": CogVideoTransformerEdit, - "CogVideoControlImageEncode": CogVideoControlImageEncode + "CogVideoControlImageEncode": CogVideoControlImageEncode, + "CogVideoLoraSelect": CogVideoLoraSelect } 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1354,5 +1390,6 @@
     "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
     "CogVideoPABConfig": "CogVideo PABConfig",
     "CogVideoTransformerEdit": "CogVideo TransformerEdit",
-    "CogVideoControlImageEncode": "CogVideo Control ImageEncode"
+    "CogVideoControlImageEncode": "CogVideo Control ImageEncode",
+    "CogVideoLoraSelect": "CogVideo LoraSelect"
 }