Support lora loading with fp8, noise augment for control input

This commit is contained in:
kijai 2024-10-04 12:48:32 +03:00
parent 3efe90ba35
commit 3c8183ac65
2 changed files with 35 additions and 22 deletions

View File

@ -366,7 +366,7 @@ def create_network(
)
return network
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None):
LORA_PREFIX_TRANSFORMER = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
if state_dict is None:
@ -380,15 +380,15 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
for layer, elems in updates.items():
if "lora_te" in layer:
if transformer_only:
continue
else:
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
curr_layer = pipeline.text_encoder
else:
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
curr_layer = pipeline.transformer
# if "lora_te" in layer:
# if transformer_only:
# continue
# else:
# layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
# curr_layer = pipeline.text_encoder
#else:
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
curr_layer = transformer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
@ -421,7 +421,7 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
else:
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
return pipeline
return transformer
# TODO: Refactor with merge_lora.
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):

View File

@ -341,6 +341,14 @@ class DownloadAndLoadCogVideoModel:
transformer = transformer.to(dtype).to(offload_device)
if lora is not None:
if lora['strength'] > 0:
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
transformer = merge_lora(transformer, lora["path"], lora["strength"])
else:
logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@ -375,13 +383,7 @@ class DownloadAndLoadCogVideoModel:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
if lora is not None:
if lora['strength'] > 0:
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
pipe = merge_lora(pipe, lora["path"], lora["strength"])
else:
logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
pipe = unmerge_lora(pipe, lora["path"], lora["strength"])
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
@ -483,8 +485,6 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config = json.load(f)
sd = load_torch_file(gguf_path)
#for key, value in sd.items():
# print(key, value.shape, value.dtype)
from . import mz_gguf_loader
import importlib
@ -530,7 +530,6 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer.to(offload_device)
else:
transformer.to(device)
if fp8_fastmode:
@ -1188,6 +1187,17 @@ class CogVideoXFunVid2VidSampler:
# pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
return (pipeline, {"samples": latents})
def add_noise_to_reference_video(image, ratio=None):
if ratio is None:
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
sigma = torch.exp(sigma).to(image.dtype)
else:
sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
image = image + image_noise
return image
class CogVideoControlImageEncode:
@classmethod
@ -1197,6 +1207,7 @@ class CogVideoControlImageEncode:
"control_video": ("IMAGE", ),
"base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
},
}
@ -1205,7 +1216,7 @@ class CogVideoControlImageEncode:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, control_video, base_resolution, enable_tiling):
def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -1239,6 +1250,8 @@ class CogVideoControlImageEncode:
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
masked_image = control_video.to(device=device, dtype=vae.dtype)
if noise_aug_strength > 0:
masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
bs = 1
new_mask_pixel_values = []
for i in range(0, masked_image.shape[0], bs):