mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-04 10:46:43 +08:00
Support lora loading with fp8, noise augment for control input
This commit is contained in:
parent
3efe90ba35
commit
3c8183ac65
@ -366,7 +366,7 @@ def create_network(
|
|||||||
)
|
)
|
||||||
return network
|
return network
|
||||||
|
|
||||||
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
|
def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None):
|
||||||
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
||||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
@ -380,15 +380,15 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
|
|||||||
|
|
||||||
for layer, elems in updates.items():
|
for layer, elems in updates.items():
|
||||||
|
|
||||||
if "lora_te" in layer:
|
# if "lora_te" in layer:
|
||||||
if transformer_only:
|
# if transformer_only:
|
||||||
continue
|
# continue
|
||||||
else:
|
# else:
|
||||||
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
# layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
||||||
curr_layer = pipeline.text_encoder
|
# curr_layer = pipeline.text_encoder
|
||||||
else:
|
#else:
|
||||||
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
|
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
|
||||||
curr_layer = pipeline.transformer
|
curr_layer = transformer
|
||||||
|
|
||||||
temp_name = layer_infos.pop(0)
|
temp_name = layer_infos.pop(0)
|
||||||
while len(layer_infos) > -1:
|
while len(layer_infos) > -1:
|
||||||
@ -421,7 +421,7 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
|
|||||||
else:
|
else:
|
||||||
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
|
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
|
||||||
|
|
||||||
return pipeline
|
return transformer
|
||||||
|
|
||||||
# TODO: Refactor with merge_lora.
|
# TODO: Refactor with merge_lora.
|
||||||
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
|
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
|
||||||
|
|||||||
35
nodes.py
35
nodes.py
@ -341,6 +341,14 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
|
|
||||||
transformer = transformer.to(dtype).to(offload_device)
|
transformer = transformer.to(dtype).to(offload_device)
|
||||||
|
|
||||||
|
if lora is not None:
|
||||||
|
if lora['strength'] > 0:
|
||||||
|
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
|
||||||
|
transformer = merge_lora(transformer, lora["path"], lora["strength"])
|
||||||
|
else:
|
||||||
|
logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
|
||||||
|
transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
|
||||||
|
|
||||||
if block_edit is not None:
|
if block_edit is not None:
|
||||||
transformer = remove_specific_blocks(transformer, block_edit)
|
transformer = remove_specific_blocks(transformer, block_edit)
|
||||||
|
|
||||||
@ -375,13 +383,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
||||||
|
|
||||||
if lora is not None:
|
|
||||||
if lora['strength'] > 0:
|
|
||||||
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
|
|
||||||
pipe = merge_lora(pipe, lora["path"], lora["strength"])
|
|
||||||
else:
|
|
||||||
logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
|
|
||||||
pipe = unmerge_lora(pipe, lora["path"], lora["strength"])
|
|
||||||
|
|
||||||
if enable_sequential_cpu_offload:
|
if enable_sequential_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
@ -483,8 +485,6 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
transformer_config = json.load(f)
|
transformer_config = json.load(f)
|
||||||
|
|
||||||
sd = load_torch_file(gguf_path)
|
sd = load_torch_file(gguf_path)
|
||||||
#for key, value in sd.items():
|
|
||||||
# print(key, value.shape, value.dtype)
|
|
||||||
|
|
||||||
from . import mz_gguf_loader
|
from . import mz_gguf_loader
|
||||||
import importlib
|
import importlib
|
||||||
@ -530,7 +530,6 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
transformer.to(offload_device)
|
transformer.to(offload_device)
|
||||||
else:
|
else:
|
||||||
transformer.to(device)
|
transformer.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if fp8_fastmode:
|
if fp8_fastmode:
|
||||||
@ -1188,6 +1187,17 @@ class CogVideoXFunVid2VidSampler:
|
|||||||
# pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
|
# pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
|
||||||
return (pipeline, {"samples": latents})
|
return (pipeline, {"samples": latents})
|
||||||
|
|
||||||
|
def add_noise_to_reference_video(image, ratio=None):
    """Return ``image`` with Gaussian noise added, scaled per leading-dim entry.

    One noise scale (sigma) is used for each index along ``image``'s first
    dimension and broadcast over every remaining dimension, so the function
    works for inputs of any rank >= 1 (the caller in this file passes a
    5-D tensor).

    Args:
        image: Tensor to perturb. Elements exactly equal to ``-1`` receive no
            noise (presumably a mask/padding marker -- confirm against the
            caller).
        ratio: Noise scale applied to every sample. When ``None``, a separate
            sigma is drawn per sample as ``exp(N(-3.0, 0.5))``.

    Returns:
        A new tensor with the same shape as ``image``; the input is not
        modified in place.
    """
    batch_size = image.shape[0]

    if ratio is None:
        # Sampled without a device argument and then moved to the image's
        # device, exactly as before, so the random stream is unchanged.
        log_sigma = torch.normal(mean=-3.0, std=0.5, size=(batch_size,)).to(image.device)
        sigma = torch.exp(log_sigma).to(image.dtype)
    else:
        sigma = torch.ones((batch_size,)).to(image.device, image.dtype) * ratio

    # Reshape to (B, 1, 1, ...) matching the input's rank instead of assuming
    # five dimensions; a hard-coded 5-D index silently mis-broadcasts 4-D input.
    sigma = sigma.reshape(-1, *([1] * (image.dim() - 1)))

    image_noise = torch.randn_like(image) * sigma
    # Leave positions whose value is exactly -1 untouched.
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    return image + image_noise
|
||||||
|
|
||||||
class CogVideoControlImageEncode:
|
class CogVideoControlImageEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1197,6 +1207,7 @@ class CogVideoControlImageEncode:
|
|||||||
"control_video": ("IMAGE", ),
|
"control_video": ("IMAGE", ),
|
||||||
"base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
|
"base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
|
||||||
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
||||||
|
"noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1205,7 +1216,7 @@ class CogVideoControlImageEncode:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, control_video, base_resolution, enable_tiling):
|
def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
|
|
||||||
@ -1239,6 +1250,8 @@ class CogVideoControlImageEncode:
|
|||||||
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
||||||
|
|
||||||
masked_image = control_video.to(device=device, dtype=vae.dtype)
|
masked_image = control_video.to(device=device, dtype=vae.dtype)
|
||||||
|
if noise_aug_strength > 0:
|
||||||
|
masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
|
||||||
bs = 1
|
bs = 1
|
||||||
new_mask_pixel_values = []
|
new_mask_pixel_values = []
|
||||||
for i in range(0, masked_image.shape[0], bs):
|
for i in range(0, masked_image.shape[0], bs):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user