From 3c8183ac65d21f6c70af6bab8d5e84d65d7bf27e Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 4 Oct 2024 12:48:32 +0300
Subject: [PATCH] Support lora loading with fp8, noise augment for control
 input

---
 cogvideox_fun/lora_utils.py | 22 +++++++++++-----------
 nodes.py                    | 35 ++++++++++++++++++++++++-----------
 2 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/cogvideox_fun/lora_utils.py b/cogvideox_fun/lora_utils.py
index 37b51fc..42038a5 100644
--- a/cogvideox_fun/lora_utils.py
+++ b/cogvideox_fun/lora_utils.py
@@ -366,7 +366,7 @@ def create_network(
     )
     return network
 
-def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None):
     LORA_PREFIX_TRANSFORMER = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
     if state_dict is None:
@@ -380,15 +380,15 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
 
     for layer, elems in updates.items():
 
-        if "lora_te" in layer:
-            if transformer_only:
-                continue
-            else:
-                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
-                curr_layer = pipeline.text_encoder
-        else:
-            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
-            curr_layer = pipeline.transformer
+        # if "lora_te" in layer:
+        #     if transformer_only:
+        #         continue
+        #     else:
+        #         layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+        #         curr_layer = pipeline.text_encoder
+        #else:
+        layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+        curr_layer = transformer
 
         temp_name = layer_infos.pop(0)
         while len(layer_infos) > -1:
@@ -421,7 +421,7 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
         else:
             curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
 
-    return pipeline
+    return transformer
 
 # TODO: Refactor with merge_lora.
 def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
diff --git a/nodes.py b/nodes.py
index 560aee4..24897a5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -341,6 +341,14 @@ class DownloadAndLoadCogVideoModel:
 
         transformer = transformer.to(dtype).to(offload_device)
 
+        if lora is not None:
+            if lora['strength'] > 0:
+                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = merge_lora(transformer, lora["path"], lora["strength"])
+            else:
+                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
 
@@ -375,13 +383,7 @@ class DownloadAndLoadCogVideoModel:
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
 
         pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
-        if lora is not None:
-            if lora['strength'] > 0:
-                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = merge_lora(pipe, lora["path"], lora["strength"])
-            else:
-                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = unmerge_lora(pipe, lora["path"], lora["strength"])
+
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
 
@@ -483,8 +485,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer_config = json.load(f)
 
         sd = load_torch_file(gguf_path)
-        #for key, value in sd.items():
-        #    print(key, value.shape, value.dtype)
 
         from . import mz_gguf_loader
         import importlib
@@ -530,7 +530,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer.to(offload_device)
         else:
             transformer.to(device)
-
 
         if fp8_fastmode:
 
@@ -1188,6 +1187,17 @@ class CogVideoXFunVid2VidSampler:
             # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
         return (pipeline, {"samples": latents})
 
+def add_noise_to_reference_video(image, ratio=None):
+    if ratio is None:
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+    else:
+        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+
+    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+    image = image + image_noise
+    return image
 
 class CogVideoControlImageEncode:
     @classmethod
@@ -1197,6 +1207,7 @@ class CogVideoControlImageEncode:
                 "control_video": ("IMAGE", ),
                 "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
+                "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
         }
 
@@ -1205,7 +1216,7 @@ class CogVideoControlImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, control_video, base_resolution, enable_tiling):
+    def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
 
@@ -1239,6 +1250,8 @@ class CogVideoControlImageEncode:
 
         control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
         masked_image = control_video.to(device=device, dtype=vae.dtype)
+        if noise_aug_strength > 0:
+            masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
         bs = 1
         new_mask_pixel_values = []
         for i in range(0, masked_image.shape[0], bs):
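Note: below is a minimal, self-contained sketch of the noise augmentation this patch introduces, for trying the behavior outside ComfyUI. The function body mirrors add_noise_to_reference_video from the diff; the demo tensor shape and values are illustrative assumptions, and 0.0563 matches the node's new noise_aug_strength default.

import torch

def add_noise_to_reference_video(image, ratio=None):
    # image: [B, C, F, H, W] video tensor in [-1, 1]; pixels equal to -1 are
    # treated as padding and kept noise-free.
    if ratio is None:
        # No explicit strength: draw a per-sample sigma from a log-normal
        # distribution (log sigma ~ N(-3.0, 0.5)).
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        # Fixed strength: the same sigma for every sample in the batch.
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
    # Broadcast sigma over [C, F, H, W] and zero the noise where the input
    # is -1, so padded regions pass through unchanged.
    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    return image + image_noise

# Example with a hypothetical 8-frame clip at the node's default strength.
control = torch.rand(1, 3, 8, 64, 64) * 2 - 1
augmented = add_noise_to_reference_video(control, ratio=0.0563)
print(augmented.shape)  # torch.Size([1, 3, 8, 64, 64])

Lightly noising the control frames before VAE encoding is a conditioning-augmentation trick similar to the one used for image conditioning in Stable Video Diffusion; it discourages the sampler from copying compression artifacts out of the reference video verbatim.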