mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
Support lora loading with fp8, noise augment for control input
parent 3efe90ba35
commit 3c8183ac65
@@ -366,7 +366,7 @@ def create_network(
     )
     return network
 
-def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None):
     LORA_PREFIX_TRANSFORMER = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
     if state_dict is None:
@@ -380,15 +380,15 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
 
     for layer, elems in updates.items():
 
-        if "lora_te" in layer:
-            if transformer_only:
-                continue
-            else:
-                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
-                curr_layer = pipeline.text_encoder
-        else:
-            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
-            curr_layer = pipeline.transformer
+        # if "lora_te" in layer:
+        #     if transformer_only:
+        #         continue
+        #     else:
+        #         layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+        #         curr_layer = pipeline.text_encoder
+        # else:
+        layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+        curr_layer = transformer
 
         temp_name = layer_infos.pop(0)
         while len(layer_infos) > -1:
@@ -421,7 +421,7 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
         else:
             curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
 
-    return pipeline
+    return transformer
 
 # TODO: Refactor with merge_lora.
 def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
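With the text-encoder branch commented out, merge_lora now patches only transformer weights and takes the bare module instead of a whole pipeline. For fp8-quantized checkpoints the low-rank update cannot be computed directly, since torch.mm is not implemented for float8 dtypes; a minimal sketch of the usual workaround (hypothetical helper, not the repository's exact code) is:

import torch

def apply_lora_delta(weight, weight_up, weight_down, multiplier, alpha):
    # Hypothetical helper: compute the low-rank LoRA update in float32,
    # then cast the merged result back to the stored (possibly fp8) dtype.
    orig_dtype = weight.dtype
    delta = multiplier * alpha * torch.mm(weight_up.float(), weight_down.float())
    return (weight.float() + delta).to(orig_dtype)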
nodes.py (35 changed lines)
@@ -341,6 +341,14 @@ class DownloadAndLoadCogVideoModel:
 
         transformer = transformer.to(dtype).to(offload_device)
 
+        if lora is not None:
+            if lora['strength'] > 0:
+                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = merge_lora(transformer, lora["path"], lora["strength"])
+            else:
+                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
+                transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
 
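The merge now runs on the bare transformer right after the dtype cast, rather than on the assembled pipeline as before (the old code is removed in the next hunk), which appears to be what enables LoRA loading with fp8-cast weights. For reference, a minimal sketch of the lora argument these branches assume — the dict keys are taken from the diff above, the values are hypothetical:

# Hypothetical example; 'path' and 'strength' are the keys used in the diff.
lora = {"path": "loras/my_cogvideox_lora.safetensors", "strength": 1.0}
# strength > 0 merges the LoRA into the transformer;
# strength <= 0 unmerges previously merged weights instead.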
@@ -375,13 +383,7 @@ class DownloadAndLoadCogVideoModel:
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
         pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
 
-        if lora is not None:
-            if lora['strength'] > 0:
-                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = merge_lora(pipe, lora["path"], lora["strength"])
-            else:
-                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
-                pipe = unmerge_lora(pipe, lora["path"], lora["strength"])
 
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
@@ -483,8 +485,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer_config = json.load(f)
 
         sd = load_torch_file(gguf_path)
-        #for key, value in sd.items():
-        #    print(key, value.shape, value.dtype)
 
         from . import mz_gguf_loader
         import importlib
@@ -530,7 +530,6 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer.to(offload_device)
         else:
             transformer.to(device)
-
 
 
         if fp8_fastmode:
@@ -1188,6 +1187,17 @@ class CogVideoXFunVid2VidSampler:
             # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
         return (pipeline, {"samples": latents})
 
+def add_noise_to_reference_video(image, ratio=None):
+    if ratio is None:
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+    else:
+        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+
+    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+    image = image + image_noise
+    return image
 
 class CogVideoControlImageEncode:
     @classmethod
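add_noise_to_reference_video adds per-sample Gaussian noise to the control frames: with ratio=None, sigma is drawn log-normally as exp(N(-3.0, 0.5)) (so typically around exp(-3) ≈ 0.05), while a fixed ratio applies one sigma to every sample; pixels equal to -1 are treated as padding and left noise-free. A minimal sketch of calling it on a dummy 5D video tensor (shapes assumed, not from the repository):

import torch

# Dummy control video in [-1, 1]: (batch, channels, frames, height, width)
video = torch.rand(2, 3, 9, 64, 64) * 2 - 1

noisy = add_noise_to_reference_video(video, ratio=0.0563)
# Mean absolute perturbation is roughly 0.0563 * sqrt(2/pi) ≈ 0.045.
print((noisy - video).abs().mean())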
@@ -1197,6 +1207,7 @@ class CogVideoControlImageEncode:
                 "control_video": ("IMAGE", ),
                 "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
+                "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
         }
@@ -1205,7 +1216,7 @@ class CogVideoControlImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, control_video, base_resolution, enable_tiling):
+    def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
@@ -1239,6 +1250,8 @@ class CogVideoControlImageEncode:
         control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
 
         masked_image = control_video.to(device=device, dtype=vae.dtype)
+        if noise_aug_strength > 0:
+            masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
         bs = 1
         new_mask_pixel_values = []
         for i in range(0, masked_image.shape[0], bs):
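This last hunk perturbs the control frames just before VAE encoding, a regularization similar to the conditioning-noise augmentation used in Stable Video Diffusion-style pipelines, and only when the strength is positive, so the feature can be disabled entirely. A small sketch of the gate in isolation (function name and shapes assumed, not the node's exact code):

import torch

def prepare_control(masked_image: torch.Tensor, noise_aug_strength: float = 0.0563) -> torch.Tensor:
    # Hypothetical stand-in for the node's pre-encode step: augment only
    # when strength is positive, so noise_aug_strength=0.0 is a no-op.
    if noise_aug_strength > 0:
        masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
    return masked_image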