testing Tora for I2V

kijai 2024-10-21 22:53:36 +03:00
parent 81f8ca676e
commit a654821515
4 changed files with 1479 additions and 111 deletions


@@ -23,7 +23,7 @@ import numpy as np
from einops import rearrange
from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import is_torch_version, logging
+from diffusers.utils import logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor
@@ -37,11 +37,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
-SAGEATTN_IS_AVAVILABLE = True
+SAGEATTN_IS_AVAILABLE = True
logger.info("Using sageattn")
except:
logger.info("sageattn not found, using sdpa")
-SAGEATTN_IS_AVAVILABLE = False
+SAGEATTN_IS_AVAILABLE = False
class CogVideoXAttnProcessor2_0:
r"""
@@ -97,7 +97,7 @@ class CogVideoXAttnProcessor2_0:
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-if SAGEATTN_IS_AVAVILABLE:
+if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
@@ -171,7 +171,7 @@ class FusedCogVideoXAttnProcessor2_0:
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-if SAGEATTN_IS_AVAVILABLE:
+if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
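
Both attention processors now read the corrected SAGEATTN_IS_AVAILABLE flag. The pattern is a plain optional-dependency dispatch; a minimal self-contained sketch of the same idea (assuming 4-D query/key/value tensors, independent of this file):

    import torch
    import torch.nn.functional as F

    # Optional fast-attention backend; fall back to PyTorch SDPA when absent.
    try:
        from sageattention import sageattn
        SAGEATTN_IS_AVAILABLE = True
    except ImportError:
        SAGEATTN_IS_AVAILABLE = False

    def dispatch_attention(query, key, value):
        # query/key/value: (batch, heads, seq_len, head_dim)
        if SAGEATTN_IS_AVAILABLE:
            return sageattn(query, key, value, is_causal=False)
        return F.scaled_dot_product_attention(query, key, value, is_causal=False)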

File diff suppressed because it is too large

nodes.py (241 changed lines)

@@ -1,5 +1,6 @@
import os
import torch
+import torch.nn as nn
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
@@ -420,38 +421,6 @@ class DownloadAndLoadCogVideoModel:
fuse_qkv_projections=True if pab_config is None else False,
)
if "Tora" in model:
import torch.nn as nn
from .tora.traj_module import MGF
hidden_size = 3072
num_layers = transformer.num_layers
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
fuser_sd = load_torch_file(os.path.join(base_path, "fuser", "fuser.safetensors"))
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
for module in transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(torch.float16)
del fuser_sd
from .tora.traj_module import TrajExtractor
traj_extractor = TrajExtractor(
vae_downsize=(4, 8, 8),
patch_size=2,
nums_rb=2,
cin=vae.config.latent_channels,
channels=[128] * transformer.num_layers,
sk=True,
use_conv=False,
)
traj_sd = load_torch_file(os.path.join(base_path, "traj_extractor", "traj_extractor.safetensors"))
traj_extractor.load_state_dict(traj_sd)
traj_extractor.to(torch.float32).to(device)
pipe.traj_extractor = traj_extractor
pipeline = {
"pipe": pipe,
"dtype": dtype,
@@ -622,63 +591,6 @@ class DownloadAndLoadCogVideoGGUFModel:
vae.load_state_dict(vae_sd)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
if "Tora" in model:
import torch.nn as nn
from .tora.traj_module import MGF
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
if not os.path.exists(fuser_path):
log.info(f"Downloading Fuser model to: {fuser_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="kijai/CogVideoX-5b-Tora",
allow_patterns=["*fuser.safetensors*"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
hidden_size = 3072
num_layers = transformer.num_layers
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
fuser_sd = load_torch_file(fuser_path)
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
for module in transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(torch.float16)
del fuser_sd
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
if not os.path.exists(traj_extractor_path):
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="kijai/CogVideoX-5b-Tora",
allow_patterns=["*traj_extractor.safetensors*"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
from .tora.traj_module import TrajExtractor
traj_extractor = TrajExtractor(
vae_downsize=(4, 8, 8),
patch_size=2,
nums_rb=2,
cin=vae.config.latent_channels,
channels=[128] * transformer.num_layers,
sk=True,
use_conv=False,
)
traj_sd = load_torch_file(traj_extractor_path)
traj_extractor.load_state_dict(traj_sd)
traj_extractor.to(torch.float32).to(device)
pipe.traj_extractor = traj_extractor
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
@@ -694,6 +606,114 @@ class DownloadAndLoadCogVideoGGUFModel:
return (pipeline,)
+class DownloadAndLoadToraModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (
+                    [
+                        "kijai/CogVideoX-5b-Tora",
+                    ],
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("TORAMODEL",)
+    RETURN_NAMES = ("tora_model", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Downloads and loads the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'"
+
+    def loadmodel(self, model):
+        check_diffusers_version()
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        download_path = folder_paths.get_folder_paths("CogVideo")[0]
+
+        from .tora.traj_module import MGF
+        try:
+            from accelerate import init_empty_weights
+            from accelerate.utils import set_module_tensor_to_device
+            is_accelerate_available = True
+        except:
+            is_accelerate_available = False
+            pass
+
+        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
+        fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
+        if not os.path.exists(fuser_path):
+            log.info(f"Downloading Fuser model to: {fuser_path}")
+            from huggingface_hub import snapshot_download
+            snapshot_download(
+                repo_id=model,
+                allow_patterns=["*fuser.safetensors*"],
+                local_dir=download_path,
+                local_dir_use_symlinks=False,
+            )
+
+        hidden_size = 3072
+        num_layers = 42
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
+
+        fuser_sd = load_torch_file(fuser_path)
+        if is_accelerate_available:
+            for key in fuser_sd:
+                set_module_tensor_to_device(fuser_list, key, dtype=torch.float16, device=device, value=fuser_sd[key])
+        else:
+            fuser_list.load_state_dict(fuser_sd)
+            for module in fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(torch.float16).to(device)
+        del fuser_sd
+
+        traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
+        if not os.path.exists(traj_extractor_path):
+            log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
+            from huggingface_hub import snapshot_download
+            snapshot_download(
+                repo_id="kijai/CogVideoX-5b-Tora",
+                allow_patterns=["*traj_extractor.safetensors*"],
+                local_dir=download_path,
+                local_dir_use_symlinks=False,
+            )
+
+        from .tora.traj_module import TrajExtractor
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            traj_extractor = TrajExtractor(
+                vae_downsize=(4, 8, 8),
+                patch_size=2,
+                nums_rb=2,
+                cin=16,
+                channels=[128] * 42,
+                sk=True,
+                use_conv=False,
+            )
+
+        traj_sd = load_torch_file(traj_extractor_path)
+        if is_accelerate_available:
+            for key in traj_sd:
+                set_module_tensor_to_device(traj_extractor, key, dtype=torch.float32, device=device, value=traj_sd[key])
+        else:
+            traj_extractor.load_state_dict(traj_sd)
+            traj_extractor.to(torch.float32).to(device)
+
+        toramodel = {
+            "fuser_list": fuser_list,
+            "traj_extractor": traj_extractor,
+        }
+
+        return (toramodel,)
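
For reference, a condensed sketch of the meta-weight loading pattern the new node uses: build the module under init_empty_weights() so no real storage is allocated, then materialize each tensor once, directly at its target device and dtype, with set_module_tensor_to_device(). SmallNet is a hypothetical stand-in for MGF; the rest follows the calls in the hunk above (nullcontext comes from contextlib, which the node code assumes is already imported):

    import torch
    import torch.nn as nn
    from contextlib import nullcontext

    try:
        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        is_accelerate_available = True
    except ImportError:
        is_accelerate_available = False

    class SmallNet(nn.Module):  # hypothetical stand-in for MGF
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

    def load_module(state_dict, dim=128, device="cpu", dtype=torch.float16):
        # Parameters land on the meta device: shapes only, no memory.
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            module = SmallNet(dim)
        if is_accelerate_available:
            # Materialize each tensor straight into the target device/dtype.
            for key in state_dict:
                set_module_tensor_to_device(
                    module, key, device=device, dtype=dtype, value=state_dict[key]
                )
        else:
            module.load_state_dict(state_dict)
            module.to(device=device, dtype=dtype)
        return module
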
class DownloadAndLoadCogVideoControlNet:
@classmethod
def INPUT_TYPES(s):
@@ -1060,11 +1080,14 @@ class ToraEncodeTrajectory:
def INPUT_TYPES(s):
return {"required": {
"pipeline": ("COGVIDEOPIPE",),
"tora_model": ("TORAMODEL",),
"coordinates": ("STRING", {"forceInput": True}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
@@ -1073,13 +1096,12 @@ class ToraEncodeTrajectory:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
-def encode(self, pipeline, width, height, num_frames, coordinates, strength):
+def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
-traj_extractor = pipeline["pipe"].traj_extractor
vae = pipeline["pipe"].vae
vae.enable_slicing()
vae._clear_fake_context_parallel_cache()
@@ -1108,22 +1130,33 @@ class ToraEncodeTrajectory:
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
vae.to(offload_device)
-video_flow_features = traj_extractor(video_flow.to(torch.float32))
+video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features)
video_flow_features = video_flow_features * strength
logging.info(f"video_flow shape: {video_flow.shape}")
-return (video_flow_features, video_flow_image.cpu().float())
+tora = {
+    "video_flow_features" : video_flow_features,
+    "start_percent" : start_percent,
+    "end_percent" : end_percent,
+    "traj_extractor" : tora_model["traj_extractor"],
+    "fuser_list" : tora_model["fuser_list"],
+}
+return (tora, video_flow_image.cpu().float())
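
Both encode nodes now return this dict in place of a bare feature tensor, so the sampler receives the features together with their scheduling window and the fuser modules. A hypothetical TypedDict (not in the source) summarizing its shape:

    from typing import TypedDict
    import torch
    import torch.nn as nn

    class ToraConditioning(TypedDict):
        # Per-layer flow features from the trajectory extractor, stacked
        # layer-major and already scaled by the node's strength input.
        video_flow_features: torch.Tensor
        start_percent: float  # fraction of denoising steps where features switch on
        end_percent: float    # fraction of denoising steps where features switch off
        traj_extractor: nn.Module
        fuser_list: nn.ModuleList
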
class ToraEncodeOpticalFlow:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"pipeline": ("COGVIDEOPIPE",),
"tora_model": ("TORAMODEL",),
"optical_flow": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
@@ -1133,14 +1166,13 @@ class ToraEncodeOpticalFlow:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
-def encode(self, pipeline, optical_flow, strength):
+def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent):
check_diffusers_version()
B, H, W, C = optical_flow.shape
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
-traj_extractor = pipeline["pipe"].traj_extractor
vae = pipeline["pipe"].vae
vae.enable_slicing()
vae._clear_fake_context_parallel_cache()
@@ -1157,14 +1189,22 @@ class ToraEncodeOpticalFlow:
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
vae.to(offload_device)
-video_flow_features = traj_extractor(video_flow.to(torch.float32))
+video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features)
video_flow_features = video_flow_features * strength
logging.info(f"video_flow shape: {video_flow.shape}")
-return (video_flow_features, )
+tora = {
+    "video_flow_features" : video_flow_features,
+    "start_percent" : start_percent,
+    "end_percent" : end_percent,
+    "traj_extractor" : tora_model["traj_extractor"],
+    "fuser_list" : tora_model["fuser_list"],
+}
+return (tora, )
@@ -1227,6 +1267,9 @@ class CogVideoSampler:
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
+if tora_trajectory is not None:
+    pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
if context_options is not None:
context_frames = context_options["context_frames"] // 4
context_stride = context_options["context_stride"] // 4
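
The // 4 on the context options converts user-facing pixel-frame counts to latent-frame counts: the CogVideoX VAE compresses time by a factor of 4, matching the vae_downsize=(4, 8, 8) used for the trajectory extractor above. A quick sanity check of the arithmetic:

    # (temporal, height, width) compression, as in TrajExtractor's vae_downsize
    vae_downsize = (4, 8, 8)

    context_frames_px = 48                                        # node setting
    context_frames_latent = context_frames_px // vae_downsize[0]  # -> 12 latent frames
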
@@ -1262,7 +1305,7 @@ class CogVideoSampler:
context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None,
controlnet=controlnet,
-video_flow_features=tora_trajectory if tora_trajectory is not None else None,
+tora=tora_trajectory if tora_trajectory is not None else None,
)
if not pipeline["cpu_offloading"]:
pipe.transformer.to(offload_device)
@@ -1809,6 +1852,7 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
"ToraEncodeTrajectory": ToraEncodeTrajectory,
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1831,4 +1875,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
"DownloadAndLoadToraModel": "(Down)load Tora Model",
}
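
Taken together, the nodes.py changes move Tora weight loading out of the pipeline loaders into the dedicated node registered here. Conceptually the wiring is loader → encode → sampler; as direct Python (hypothetical values, ComfyUI normally drives these calls through the graph):

    # Hypothetical direct invocation of the new nodes, outside the ComfyUI graph.
    (tora_model,) = DownloadAndLoadToraModel().loadmodel(model="kijai/CogVideoX-5b-Tora")

    tora, flow_preview = ToraEncodeTrajectory().encode(
        pipeline=pipeline,          # COGVIDEOPIPE from a CogVideo loader node
        tora_model=tora_model,
        coordinates=coords,         # trajectory string from a coordinates node
        width=720, height=480, num_frames=49,
        strength=1.0, start_percent=0.0, end_percent=1.0,
    )
    # `tora` then feeds CogVideoSampler's tora_trajectory input, which installs
    # tora["fuser_list"] on the transformer and forwards the dict to the pipeline.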


@@ -161,7 +161,6 @@ class CogVideoXPipeline(VideoSysPipeline):
self.original_mask = original_mask
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
-self.traj_extractor = None
if pab_config is not None:
set_pab_manager(pab_config)
@@ -390,7 +389,7 @@ class CogVideoXPipeline(VideoSysPipeline):
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
controlnet: Optional[dict] = None,
-video_flow_features: Optional[torch.Tensor] = None,
+tora: Optional[dict] = None,
):
"""
@@ -582,8 +581,8 @@ class CogVideoXPipeline(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings
else None
)
-if video_flow_features is not None and do_classifier_free_guidance:
-    video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
+if tora is not None and do_classifier_free_guidance:
+    tora["video_flow_features"] = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
# 9. Controlnet
if controlnet is not None:
@@ -784,11 +783,11 @@ class CogVideoXPipeline(VideoSysPipeline):
else:
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
-if video_flow_features is not None:
+if tora is not None:
if do_classifier_free_guidance:
-partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
+partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
else:
-partial_video_flow_features = video_flow_features[:, c, :, :, :]
+partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :]
else:
partial_video_flow_features = None
@@ -869,7 +868,7 @@ class CogVideoXPipeline(VideoSysPipeline):
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_weights,
-video_flow_features=video_flow_features,
+video_flow_features=tora["video_flow_features"] if (tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()
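
The conditional added at the transformer call is the point of the new start_percent/end_percent inputs: flow features are injected only inside that window of the denoising schedule. A minimal sketch of the same gate, with an explicit None guard added since the hunk assumes tora is set at this call site:

    def gated_flow_features(tora, current_step_percentage):
        # Return Tora flow features only inside [start_percent, end_percent]
        # of the denoising schedule; otherwise the transformer gets None.
        if tora is None:
            return None
        if tora["start_percent"] <= current_step_percentage <= tora["end_percent"]:
            return tora["video_flow_features"]
        return None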