Support lora loading with fp8, noise augment for control input

kijai 2024-10-04 12:48:32 +03:00
parent 3efe90ba35
commit 3c8183ac65
2 changed files with 35 additions and 22 deletions


@@ -366,7 +366,7 @@ def create_network(
     )
     return network
 
-def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None):
     LORA_PREFIX_TRANSFORMER = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
     if state_dict is None:
@@ -380,15 +380,15 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
     for layer, elems in updates.items():
 
-        if "lora_te" in layer:
-            if transformer_only:
-                continue
-            else:
-                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
-                curr_layer = pipeline.text_encoder
-        else:
-            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
-            curr_layer = pipeline.transformer
+        # if "lora_te" in layer:
+        #     if transformer_only:
+        #         continue
+        #     else:
+        #         layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+        #         curr_layer = pipeline.text_encoder
+        #else:
+        layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+        curr_layer = transformer
 
         temp_name = layer_infos.pop(0)
         while len(layer_infos) > -1:
@@ -421,7 +421,7 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
         else:
             curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
 
-    return pipeline
+    return transformer
 
 # TODO: Refactor with merge_lora.
 def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
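For reference, the per-layer update that merge_lora applies (visible in the last hunk above) is the standard LoRA merge: W += multiplier * alpha * (up @ down). A minimal standalone sketch of that update follows; the layer size, rank, and scaling values are illustrative, not taken from the repo.

import torch

def apply_lora_merge(weight, weight_up, weight_down, alpha, multiplier):
    # W <- W + multiplier * alpha * (up @ down), the same update as the
    # curr_layer.weight.data += ... line in merge_lora
    return weight + multiplier * alpha * torch.mm(weight_up, weight_down)

# Illustrative shapes: a 64x64 linear layer with a rank-8 LoRA.
layer = torch.nn.Linear(64, 64)
rank = 8
weight_up = torch.randn(64, rank) * 0.01    # "lora_up" factor
weight_down = torch.randn(rank, 64) * 0.01  # "lora_down" factor
layer.weight.data = apply_lora_merge(layer.weight.data, weight_up,
                                     weight_down, alpha=1.0, multiplier=1.0)

Because the merge is purely additive, unmerge_lora can undo it by subtracting the same product, which is why the two functions share most of their logic (hence the TODO above).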


@@ -341,6 +341,14 @@ class DownloadAndLoadCogVideoModel:
         transformer = transformer.to(dtype).to(offload_device)
 
+        if lora is not None:
+            if lora['strength'] > 0:
+                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = merge_lora(transformer, lora["path"], lora["strength"])
+            else:
+                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
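This move is presumably what enables LoRA loading with fp8 (per the commit title): the merge now runs on the bare transformer right after loading, while its weights are still in a dtype that torch.mm supports, rather than on the assembled pipeline where the transformer may already have been cast to float8. A hypothetical sketch of the ordering, assuming an fp8_transformer-style cast happens later in the loader (the cast below is my assumption, not code from this diff):

import torch

transformer = torch.nn.Linear(16, 16).to(torch.bfloat16)  # stand-in for the model

# 1) Merge LoRA while the weights are still bf16/fp16:
#    transformer = merge_lora(transformer, lora["path"], lora["strength"])

# 2) Only then cast weights to float8 (available in PyTorch >= 2.1); merging
#    after this cast would fail, since plain torch.mm has no float8 kernel.
for param in transformer.parameters():
    param.data = param.data.to(torch.float8_e4m3fn)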
@@ -375,13 +383,7 @@ class DownloadAndLoadCogVideoModel:
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
 
         pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
 
-        if lora is not None:
-            if lora['strength'] > 0:
-                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = merge_lora(pipe, lora["path"], lora["strength"])
-            else:
-                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = unmerge_lora(pipe, lora["path"], lora["strength"])
 
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
@@ -483,8 +485,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer_config = json.load(f)
 
         sd = load_torch_file(gguf_path)
-        #for key, value in sd.items():
-        #    print(key, value.shape, value.dtype)
 
         from . import mz_gguf_loader
         import importlib
@@ -530,7 +530,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer.to(offload_device)
         else:
            transformer.to(device)
-
        if fp8_fastmode:
@@ -1188,6 +1187,17 @@ class CogVideoXFunVid2VidSampler:
         # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
         return (pipeline, {"samples": latents})
 
+def add_noise_to_reference_video(image, ratio=None):
+    if ratio is None:
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+    else:
+        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+    image = image + image_noise
+    return image
+
 class CogVideoControlImageEncode:
     @classmethod
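One detail worth noting about the noise_aug_strength default introduced below (0.0563): it appears to match the expected value of the log-normal sigma that add_noise_to_reference_video draws when ratio is None, since E[exp(X)] = exp(mu + std^2/2) = exp(-3.0 + 0.125) ≈ 0.0564. A quick check of that reading (my inference; the commit does not state it):

import torch

# Sample the sigma distribution used when ratio is None:
# log(sigma) ~ Normal(mean=-3.0, std=0.5)
sigma = torch.exp(torch.normal(mean=-3.0, std=0.5, size=(1_000_000,)))
print(sigma.mean().item())  # ~0.0564 = exp(-3.0 + 0.5**2 / 2)

Also note the torch.where guard in the function: pixels equal to -1 (presumably the fill value for masked regions) receive no noise, so the augmentation only perturbs real reference content.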
@@ -1197,6 +1207,7 @@ class CogVideoControlImageEncode:
                 "control_video": ("IMAGE", ),
                 "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
+                "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
         }
@@ -1205,7 +1216,7 @@ class CogVideoControlImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, control_video, base_resolution, enable_tiling):
+    def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -1239,6 +1250,8 @@ class CogVideoControlImageEncode:
         control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
 
         masked_image = control_video.to(device=device, dtype=vae.dtype)
+        if noise_aug_strength > 0:
+            masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
         bs = 1
         new_mask_pixel_values = []
         for i in range(0, masked_image.shape[0], bs):