Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2026-01-23 10:24:24 +08:00)
testing Tora for I2V
This commit is contained in:
parent 81f8ca676e
commit a654821515
@@ -23,7 +23,7 @@ import numpy as np
 from einops import rearrange
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import is_torch_version, logging
+from diffusers.utils import logging
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from diffusers.models.attention import Attention, FeedForward
 from diffusers.models.attention_processor import AttentionProcessor
@@ -37,11 +37,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 try:
     from sageattention import sageattn
-    SAGEATTN_IS_AVAVILABLE = True
+    SAGEATTN_IS_AVAILABLE = True
     logger.info("Using sageattn")
 except:
     logger.info("sageattn not found, using sdpa")
-    SAGEATTN_IS_AVAVILABLE = False
+    SAGEATTN_IS_AVAILABLE = False
 
 class CogVideoXAttnProcessor2_0:
     r"""
@@ -97,7 +97,7 @@ class CogVideoXAttnProcessor2_0:
         if not attn.is_cross_attention:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-        if SAGEATTN_IS_AVAVILABLE:
+        if SAGEATTN_IS_AVAILABLE:
             hidden_states = sageattn(query, key, value, is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
@@ -171,7 +171,7 @@ class FusedCogVideoXAttnProcessor2_0:
         if not attn.is_cross_attention:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-        if SAGEATTN_IS_AVAVILABLE:
+        if SAGEATTN_IS_AVAILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
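
Aside from the flag-name typo fix, these hunks show the optional-backend pattern in full. Reduced to a standalone sketch (tensor shapes are hypothetical; the (batch, heads, seq_len, head_dim) layout is assumed from the SDPA fallback, and sageattn is assumed to accept the same layout as the diff's call sites suggest):

import torch
import torch.nn.functional as F

# Optional fast attention backend, with SDPA as the fallback --
# the same try/except pattern the diff's renamed flag guards.
try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

def attention(query, key, value):
    # query/key/value: (batch, heads, seq_len, head_dim), hypothetical shapes
    if SAGEATTN_IS_AVAILABLE:
        return sageattn(query, key, value, is_causal=False)
    return F.scaled_dot_product_attention(query, key, value)
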
examples/cogvideox_5b_Tora_I2V_testing_01.json (new file, 1324 lines)
File diff suppressed because it is too large
nodes.py (241 changed lines)
@@ -1,5 +1,6 @@
 import os
 import torch
+import torch.nn as nn
 import folder_paths
 import comfy.model_management as mm
 from comfy.utils import ProgressBar, load_torch_file
@@ -420,38 +421,6 @@ class DownloadAndLoadCogVideoModel:
             fuse_qkv_projections=True if pab_config is None else False,
         )
 
-        if "Tora" in model:
-            import torch.nn as nn
-            from .tora.traj_module import MGF
-
-            hidden_size = 3072
-            num_layers = transformer.num_layers
-            pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
-            fuser_sd = load_torch_file(os.path.join(base_path, "fuser", "fuser.safetensors"))
-            pipe.transformer.fuser_list.load_state_dict(fuser_sd)
-            for module in transformer.fuser_list:
-                for param in module.parameters():
-                    param.data = param.data.to(torch.float16)
-            del fuser_sd
-
-            from .tora.traj_module import TrajExtractor
-            traj_extractor = TrajExtractor(
-                vae_downsize=(4, 8, 8),
-                patch_size=2,
-                nums_rb=2,
-                cin=vae.config.latent_channels,
-                channels=[128] * transformer.num_layers,
-                sk=True,
-                use_conv=False,
-            )
-
-            traj_sd = load_torch_file(os.path.join(base_path, "traj_extractor", "traj_extractor.safetensors"))
-            traj_extractor.load_state_dict(traj_sd)
-            traj_extractor.to(torch.float32).to(device)
-
-            pipe.traj_extractor = traj_extractor
-
-
         pipeline = {
             "pipe": pipe,
             "dtype": dtype,
@@ -622,63 +591,6 @@ class DownloadAndLoadCogVideoGGUFModel:
         vae.load_state_dict(vae_sd)
         pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
 
-        if "Tora" in model:
-            import torch.nn as nn
-            from .tora.traj_module import MGF
-
-            download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
-            fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
-            if not os.path.exists(fuser_path):
-                log.info(f"Downloading Fuser model to: {fuser_path}")
-                from huggingface_hub import snapshot_download
-
-                snapshot_download(
-                    repo_id="kijai/CogVideoX-5b-Tora",
-                    allow_patterns=["*fuser.safetensors*"],
-                    local_dir=download_path,
-                    local_dir_use_symlinks=False,
-                )
-
-            hidden_size = 3072
-            num_layers = transformer.num_layers
-            pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
-
-            fuser_sd = load_torch_file(fuser_path)
-            pipe.transformer.fuser_list.load_state_dict(fuser_sd)
-            for module in transformer.fuser_list:
-                for param in module.parameters():
-                    param.data = param.data.to(torch.float16)
-            del fuser_sd
-
-            traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
-            if not os.path.exists(traj_extractor_path):
-                log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
-                from huggingface_hub import snapshot_download
-
-                snapshot_download(
-                    repo_id="kijai/CogVideoX-5b-Tora",
-                    allow_patterns=["*traj_extractor.safetensors*"],
-                    local_dir=download_path,
-                    local_dir_use_symlinks=False,
-                )
-
-            from .tora.traj_module import TrajExtractor
-            traj_extractor = TrajExtractor(
-                vae_downsize=(4, 8, 8),
-                patch_size=2,
-                nums_rb=2,
-                cin=vae.config.latent_channels,
-                channels=[128] * transformer.num_layers,
-                sk=True,
-                use_conv=False,
-            )
-
-            traj_sd = load_torch_file(traj_extractor_path)
-            traj_extractor.load_state_dict(traj_sd)
-            traj_extractor.to(torch.float32).to(device)
-
-            pipe.traj_extractor = traj_extractor
-
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
 
@@ -694,6 +606,114 @@ class DownloadAndLoadCogVideoGGUFModel:
 
         return (pipeline,)
 
+class DownloadAndLoadToraModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (
+                    [
+                        "kijai/CogVideoX-5b-Tora",
+                    ],
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("TORAMODEL",)
+    RETURN_NAMES = ("tora_model", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Downloads and loads the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'"
+
+    def loadmodel(self, model):
+
+        check_diffusers_version()
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        download_path = folder_paths.get_folder_paths("CogVideo")[0]
+
+        from .tora.traj_module import MGF
+
+        try:
+            from accelerate import init_empty_weights
+            from accelerate.utils import set_module_tensor_to_device
+            is_accelerate_available = True
+        except:
+            is_accelerate_available = False
+            pass
+
+        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
+        fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
+        if not os.path.exists(fuser_path):
+            log.info(f"Downloading Fuser model to: {fuser_path}")
+            from huggingface_hub import snapshot_download
+
+            snapshot_download(
+                repo_id=model,
+                allow_patterns=["*fuser.safetensors*"],
+                local_dir=download_path,
+                local_dir_use_symlinks=False,
+            )
+
+        hidden_size = 3072
+        num_layers = 42
+
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
+
+        fuser_sd = load_torch_file(fuser_path)
+        if is_accelerate_available:
+            for key in fuser_sd:
+                set_module_tensor_to_device(fuser_list, key, dtype=torch.float16, device=device, value=fuser_sd[key])
+        else:
+            fuser_list.load_state_dict(fuser_sd)
+            for module in fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(torch.float16).to(device)
+        del fuser_sd
+
+        traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
+        if not os.path.exists(traj_extractor_path):
+            log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
+            from huggingface_hub import snapshot_download
+
+            snapshot_download(
+                repo_id="kijai/CogVideoX-5b-Tora",
+                allow_patterns=["*traj_extractor.safetensors*"],
+                local_dir=download_path,
+                local_dir_use_symlinks=False,
+            )
+
+        from .tora.traj_module import TrajExtractor
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            traj_extractor = TrajExtractor(
+                vae_downsize=(4, 8, 8),
+                patch_size=2,
+                nums_rb=2,
+                cin=16,
+                channels=[128] * 42,
+                sk=True,
+                use_conv=False,
+            )
+
+        traj_sd = load_torch_file(traj_extractor_path)
+        if is_accelerate_available:
+            for key in traj_sd:
+                set_module_tensor_to_device(traj_extractor, key, dtype=torch.float32, device=device, value=traj_sd[key])
+        else:
+            traj_extractor.load_state_dict(traj_sd)
+            traj_extractor.to(torch.float32).to(device)
+
+        toramodel = {
+            "fuser_list": fuser_list,
+            "traj_extractor": traj_extractor,
+        }
+
+        return (toramodel,)
+
 class DownloadAndLoadCogVideoControlNet:
     @classmethod
     def INPUT_TYPES(s):
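
The new DownloadAndLoadToraModel node uses accelerate's empty-weights trick so the fuser and trajectory extractor never allocate a throwaway CPU copy of their weights. In isolation the pattern looks like this (a minimal sketch; TinyModule is a hypothetical stand-in for MGF, and the state dict here is synthetic):

import torch
import torch.nn as nn
from contextlib import nullcontext

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except ImportError:
    is_accelerate_available = False

class TinyModule(nn.Module):  # hypothetical stand-in for MGF
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 3072)

# Construct on the meta device when possible, so no weight memory is allocated.
with (init_empty_weights() if is_accelerate_available else nullcontext()):
    module = TinyModule()

state_dict = {"proj.weight": torch.zeros(3072, 128), "proj.bias": torch.zeros(3072)}
if is_accelerate_available:
    # Materialize each tensor straight onto the target device at the target dtype.
    for key in state_dict:
        set_module_tensor_to_device(module, key, dtype=torch.float16, device="cpu", value=state_dict[key])
else:
    # Fallback: normal load, then cast after the fact (double allocation).
    module.load_state_dict(state_dict)
    module.half()
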
@@ -1060,11 +1080,14 @@ class ToraEncodeTrajectory:
     def INPUT_TYPES(s):
         return {"required": {
             "pipeline": ("COGVIDEOPIPE",),
+            "tora_model": ("TORAMODEL",),
             "coordinates": ("STRING", {"forceInput": True}),
             "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
             "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
             "num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }
 
@@ -1073,13 +1096,12 @@ class ToraEncodeTrajectory:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, width, height, num_frames, coordinates, strength):
+    def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model):
         check_diffusers_version()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
-        traj_extractor = pipeline["pipe"].traj_extractor
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
         vae._clear_fake_context_parallel_cache()
@@ -1108,22 +1130,33 @@ class ToraEncodeTrajectory:
         video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
         vae.to(offload_device)
 
-        video_flow_features = traj_extractor(video_flow.to(torch.float32))
+        video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
         video_flow_features = video_flow_features * strength
 
         logging.info(f"video_flow shape: {video_flow.shape}")
 
-        return (video_flow_features, video_flow_image.cpu().float())
+        tora = {
+            "video_flow_features" : video_flow_features,
+            "start_percent" : start_percent,
+            "end_percent" : end_percent,
+            "traj_extractor" : tora_model["traj_extractor"],
+            "fuser_list" : tora_model["fuser_list"],
+        }
+
+        return (tora, video_flow_image.cpu().float())
 
 class ToraEncodeOpticalFlow:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "pipeline": ("COGVIDEOPIPE",),
+            "tora_model": ("TORAMODEL",),
             "optical_flow": ("IMAGE", ),
             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
 
         }
@@ -1133,14 +1166,13 @@ class ToraEncodeOpticalFlow:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, optical_flow, strength):
+    def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent):
         check_diffusers_version()
         B, H, W, C = optical_flow.shape
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
-        traj_extractor = pipeline["pipe"].traj_extractor
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
         vae._clear_fake_context_parallel_cache()
@@ -1157,14 +1189,22 @@ class ToraEncodeOpticalFlow:
         video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
         vae.to(offload_device)
 
-        video_flow_features = traj_extractor(video_flow.to(torch.float32))
+        video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
         video_flow_features = video_flow_features * strength
 
         logging.info(f"video_flow shape: {video_flow.shape}")
 
-        return (video_flow_features, )
+        tora = {
+            "video_flow_features" : video_flow_features,
+            "start_percent" : start_percent,
+            "end_percent" : end_percent,
+            "traj_extractor" : tora_model["traj_extractor"],
+            "fuser_list" : tora_model["fuser_list"],
+        }
+
+        return (tora, )
 
 
 
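
Both encode nodes now return a tora dict instead of a bare feature tensor, carrying start_percent/end_percent alongside the features and the modules. The window test the pipeline later applies (see the video_flow_features gating in the sampling-loop hunk further down) reduces to a simple check over normalized progress. How current_step_percentage is computed is not shown in this diff, so the normalization below is an assumption:

def in_tora_window(step, num_steps, start_percent, end_percent):
    # Normalized sampler progress: 0.0 at the first step, 1.0 at the last
    # (assumed; the diff only shows the comparison, not the normalization).
    current_step_percentage = step / max(num_steps - 1, 1)
    return start_percent <= current_step_percentage <= end_percent

# With the node defaults (start 0.0, end 1.0) the features apply on every step;
# start_percent=0.5 would inject them only during the second half of sampling.
assert in_tora_window(0, 50, 0.0, 1.0)
assert not in_tora_window(10, 50, 0.5, 1.0)
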
@@ -1227,6 +1267,9 @@ class CogVideoSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
+        if tora_trajectory is not None:
+            pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
+
         if context_options is not None:
             context_frames = context_options["context_frames"] // 4
             context_stride = context_options["context_stride"] // 4
@@ -1262,7 +1305,7 @@ class CogVideoSampler:
             context_overlap= context_overlap,
             freenoise=context_options["freenoise"] if context_options is not None else None,
             controlnet=controlnet,
-            video_flow_features=tora_trajectory if tora_trajectory is not None else None,
+            tora=tora_trajectory if tora_trajectory is not None else None,
         )
         if not pipeline["cpu_offloading"]:
             pipe.transformer.to(offload_device)
@@ -1809,6 +1852,7 @@ NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
     "ToraEncodeTrajectory": ToraEncodeTrajectory,
     "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
+    "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1831,4 +1875,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
     "ToraEncodeTrajectory": "Tora Encode Trajectory",
     "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
+    "DownloadAndLoadToraModel": "(Down)load Tora Model",
 }
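
For readers unfamiliar with the registration convention these last two hunks extend: a ComfyUI custom node is a plain class exposing INPUT_TYPES/RETURN_TYPES metadata, and it becomes available once it appears in both mapping dicts. A minimal hypothetical example in the same style (ExampleNode is invented for illustration; the TORAMODEL socket type is the one this commit introduces):

class ExampleNode:  # hypothetical, for illustration only
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"tora_model": ("TORAMODEL",)}}

    RETURN_TYPES = ("TORAMODEL",)
    RETURN_NAMES = ("tora_model",)
    FUNCTION = "passthrough"
    CATEGORY = "CogVideoWrapper"

    def passthrough(self, tora_model):
        # ComfyUI calls FUNCTION with the wired inputs and expects a tuple.
        return (tora_model,)

NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}
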
@@ -161,7 +161,6 @@ class CogVideoXPipeline(VideoSysPipeline):
         self.original_mask = original_mask
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
-        self.traj_extractor = None
 
         if pab_config is not None:
             set_pab_manager(pab_config)
@@ -390,7 +389,7 @@ class CogVideoXPipeline(VideoSysPipeline):
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
         controlnet: Optional[dict] = None,
-        video_flow_features: Optional[torch.Tensor] = None,
+        tora: Optional[dict] = None,
 
     ):
         """
@@ -582,8 +581,8 @@ class CogVideoXPipeline(VideoSysPipeline):
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-        if video_flow_features is not None and do_classifier_free_guidance:
-            video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
+        if tora is not None and do_classifier_free_guidance:
+            tora["video_flow_features"] = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
 
         # 9. Controlnet
         if controlnet is not None:
@@ -784,11 +783,11 @@ class CogVideoXPipeline(VideoSysPipeline):
             else:
                 for c in context_queue:
                     partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                    if video_flow_features is not None:
+                    if tora is not None:
                        if do_classifier_free_guidance:
-                            partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
+                            partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
                        else:
-                            partial_video_flow_features = video_flow_features[:, c, :, :, :]
+                            partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :]
                    else:
                        partial_video_flow_features = None
 
@@ -869,7 +868,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                     return_dict=False,
                     controlnet_states=controlnet_states,
                     controlnet_weights=control_weights,
-                    video_flow_features=video_flow_features,
+                    video_flow_features=tora["video_flow_features"] if (tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
                 )[0]
                 noise_pred = noise_pred.float()
 
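
The repeat(1, 2, 1, 1, 1) in the classifier-free-guidance branches above duplicates the flow features along dim 1 so they line up with the doubled uncond+cond batch. In miniature (the (num_layers, batch, channels, height, width) layout and the spatial sizes are assumptions, though 42 layers and 128 channels match the loader in this commit):

import torch

# Per-layer flow features stacked as (num_layers, batch, channels, height, width);
# dim 1 is the batch axis being doubled for classifier-free guidance.
features = torch.randn(42, 1, 128, 60, 90)
doubled = features.repeat(1, 2, 1, 1, 1).contiguous()
print(doubled.shape)  # torch.Size([42, 2, 128, 60, 90])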