mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-03 05:53:34 +08:00
testing Tora for I2V
This commit is contained in:
parent
81f8ca676e
commit
a654821515
@ -23,7 +23,7 @@ import numpy as np
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
from diffusers.utils import is_torch_version, logging
|
from diffusers.utils import logging
|
||||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||||
from diffusers.models.attention import Attention, FeedForward
|
from diffusers.models.attention import Attention, FeedForward
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
@ -37,11 +37,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn
|
from sageattention import sageattn
|
||||||
SAGEATTN_IS_AVAVILABLE = True
|
SAGEATTN_IS_AVAILABLE = True
|
||||||
logger.info("Using sageattn")
|
logger.info("Using sageattn")
|
||||||
except:
|
except:
|
||||||
logger.info("sageattn not found, using sdpa")
|
logger.info("sageattn not found, using sdpa")
|
||||||
SAGEATTN_IS_AVAVILABLE = False
|
SAGEATTN_IS_AVAILABLE = False
|
||||||
|
|
||||||
class CogVideoXAttnProcessor2_0:
|
class CogVideoXAttnProcessor2_0:
|
||||||
r"""
|
r"""
|
||||||
@ -97,7 +97,7 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
if not attn.is_cross_attention:
|
if not attn.is_cross_attention:
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
if SAGEATTN_IS_AVAILABLE:
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
hidden_states = sageattn(query, key, value, is_causal=False)
|
||||||
else:
|
else:
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
@ -171,7 +171,7 @@ class FusedCogVideoXAttnProcessor2_0:
|
|||||||
if not attn.is_cross_attention:
|
if not attn.is_cross_attention:
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
if SAGEATTN_IS_AVAILABLE:
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
hidden_states = sageattn(query, key, value, is_causal=False)
|
||||||
else:
|
else:
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
|||||||
1324
examples/cogvideox_5b_Tora_I2V_testing_01.json
Normal file
1324
examples/cogvideox_5b_Tora_I2V_testing_01.json
Normal file
File diff suppressed because it is too large
Load Diff
241
nodes.py
241
nodes.py
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
from comfy.utils import ProgressBar, load_torch_file
|
from comfy.utils import ProgressBar, load_torch_file
|
||||||
@ -420,38 +421,6 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
fuse_qkv_projections=True if pab_config is None else False,
|
fuse_qkv_projections=True if pab_config is None else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "Tora" in model:
|
|
||||||
import torch.nn as nn
|
|
||||||
from .tora.traj_module import MGF
|
|
||||||
|
|
||||||
hidden_size = 3072
|
|
||||||
num_layers = transformer.num_layers
|
|
||||||
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
|
|
||||||
fuser_sd = load_torch_file(os.path.join(base_path, "fuser", "fuser.safetensors"))
|
|
||||||
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
|
|
||||||
for module in transformer.fuser_list:
|
|
||||||
for param in module.parameters():
|
|
||||||
param.data = param.data.to(torch.float16)
|
|
||||||
del fuser_sd
|
|
||||||
|
|
||||||
from .tora.traj_module import TrajExtractor
|
|
||||||
traj_extractor = TrajExtractor(
|
|
||||||
vae_downsize=(4, 8, 8),
|
|
||||||
patch_size=2,
|
|
||||||
nums_rb=2,
|
|
||||||
cin=vae.config.latent_channels,
|
|
||||||
channels=[128] * transformer.num_layers,
|
|
||||||
sk=True,
|
|
||||||
use_conv=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
traj_sd = load_torch_file(os.path.join(base_path, "traj_extractor", "traj_extractor.safetensors"))
|
|
||||||
traj_extractor.load_state_dict(traj_sd)
|
|
||||||
traj_extractor.to(torch.float32).to(device)
|
|
||||||
|
|
||||||
pipe.traj_extractor = traj_extractor
|
|
||||||
|
|
||||||
|
|
||||||
pipeline = {
|
pipeline = {
|
||||||
"pipe": pipe,
|
"pipe": pipe,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
@ -622,63 +591,6 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
vae.load_state_dict(vae_sd)
|
vae.load_state_dict(vae_sd)
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
||||||
|
|
||||||
if "Tora" in model:
|
|
||||||
import torch.nn as nn
|
|
||||||
from .tora.traj_module import MGF
|
|
||||||
|
|
||||||
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
|
|
||||||
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
|
|
||||||
if not os.path.exists(fuser_path):
|
|
||||||
log.info(f"Downloading Fuser model to: {fuser_path}")
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="kijai/CogVideoX-5b-Tora",
|
|
||||||
allow_patterns=["*fuser.safetensors*"],
|
|
||||||
local_dir=download_path,
|
|
||||||
local_dir_use_symlinks=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_size = 3072
|
|
||||||
num_layers = transformer.num_layers
|
|
||||||
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
|
|
||||||
|
|
||||||
fuser_sd = load_torch_file(fuser_path)
|
|
||||||
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
|
|
||||||
for module in transformer.fuser_list:
|
|
||||||
for param in module.parameters():
|
|
||||||
param.data = param.data.to(torch.float16)
|
|
||||||
del fuser_sd
|
|
||||||
|
|
||||||
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
|
||||||
if not os.path.exists(traj_extractor_path):
|
|
||||||
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="kijai/CogVideoX-5b-Tora",
|
|
||||||
allow_patterns=["*traj_extractor.safetensors*"],
|
|
||||||
local_dir=download_path,
|
|
||||||
local_dir_use_symlinks=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .tora.traj_module import TrajExtractor
|
|
||||||
traj_extractor = TrajExtractor(
|
|
||||||
vae_downsize=(4, 8, 8),
|
|
||||||
patch_size=2,
|
|
||||||
nums_rb=2,
|
|
||||||
cin=vae.config.latent_channels,
|
|
||||||
channels=[128] * transformer.num_layers,
|
|
||||||
sk=True,
|
|
||||||
use_conv=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
traj_sd = load_torch_file(traj_extractor_path)
|
|
||||||
traj_extractor.load_state_dict(traj_sd)
|
|
||||||
traj_extractor.to(torch.float32).to(device)
|
|
||||||
|
|
||||||
pipe.traj_extractor = traj_extractor
|
|
||||||
|
|
||||||
if enable_sequential_cpu_offload:
|
if enable_sequential_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
@ -694,6 +606,114 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
|
|
||||||
return (pipeline,)
|
return (pipeline,)
|
||||||
|
|
||||||
|
class DownloadAndLoadToraModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": (
|
||||||
|
[
|
||||||
|
"kijai/CogVideoX-5b-Tora",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("TORAMODEL",)
|
||||||
|
RETURN_NAMES = ("tora_model", )
|
||||||
|
FUNCTION = "loadmodel"
|
||||||
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
DESCRIPTION = "Downloads and loads the the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'"
|
||||||
|
|
||||||
|
def loadmodel(self, model):
|
||||||
|
|
||||||
|
check_diffusers_version()
|
||||||
|
|
||||||
|
device = mm.get_torch_device()
|
||||||
|
offload_device = mm.unet_offload_device()
|
||||||
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
|
download_path = folder_paths.get_folder_paths("CogVideo")[0]
|
||||||
|
|
||||||
|
from .tora.traj_module import MGF
|
||||||
|
|
||||||
|
try:
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
is_accelerate_available = True
|
||||||
|
except:
|
||||||
|
is_accelerate_available = False
|
||||||
|
pass
|
||||||
|
|
||||||
|
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
|
||||||
|
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
|
||||||
|
if not os.path.exists(fuser_path):
|
||||||
|
log.info(f"Downloading Fuser model to: {fuser_path}")
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=model,
|
||||||
|
allow_patterns=["*fuser.safetensors*"],
|
||||||
|
local_dir=download_path,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_size = 3072
|
||||||
|
num_layers = 42
|
||||||
|
|
||||||
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
|
fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
|
||||||
|
|
||||||
|
fuser_sd = load_torch_file(fuser_path)
|
||||||
|
if is_accelerate_available:
|
||||||
|
for key in fuser_sd:
|
||||||
|
set_module_tensor_to_device(fuser_list, key, dtype=torch.float16, device=device, value=fuser_sd[key])
|
||||||
|
else:
|
||||||
|
fuser_list.load_state_dict(fuser_sd)
|
||||||
|
for module in fuser_list:
|
||||||
|
for param in module.parameters():
|
||||||
|
param.data = param.data.to(torch.float16).to(device)
|
||||||
|
del fuser_sd
|
||||||
|
|
||||||
|
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
||||||
|
if not os.path.exists(traj_extractor_path):
|
||||||
|
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="kijai/CogVideoX-5b-Tora",
|
||||||
|
allow_patterns=["*traj_extractor.safetensors*"],
|
||||||
|
local_dir=download_path,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .tora.traj_module import TrajExtractor
|
||||||
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
|
traj_extractor = TrajExtractor(
|
||||||
|
vae_downsize=(4, 8, 8),
|
||||||
|
patch_size=2,
|
||||||
|
nums_rb=2,
|
||||||
|
cin=16,
|
||||||
|
channels=[128] * 42,
|
||||||
|
sk=True,
|
||||||
|
use_conv=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
traj_sd = load_torch_file(traj_extractor_path)
|
||||||
|
if is_accelerate_available:
|
||||||
|
for key in traj_sd:
|
||||||
|
set_module_tensor_to_device(traj_extractor, key, dtype=torch.float32, device=device, value=traj_sd[key])
|
||||||
|
else:
|
||||||
|
traj_extractor.load_state_dict(traj_sd)
|
||||||
|
traj_extractor.to(torch.float32).to(device)
|
||||||
|
|
||||||
|
toramodel = {
|
||||||
|
"fuser_list": fuser_list,
|
||||||
|
"traj_extractor": traj_extractor,
|
||||||
|
}
|
||||||
|
|
||||||
|
return (toramodel,)
|
||||||
|
|
||||||
class DownloadAndLoadCogVideoControlNet:
|
class DownloadAndLoadCogVideoControlNet:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1060,11 +1080,14 @@ class ToraEncodeTrajectory:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"pipeline": ("COGVIDEOPIPE",),
|
"pipeline": ("COGVIDEOPIPE",),
|
||||||
|
"tora_model": ("TORAMODEL",),
|
||||||
"coordinates": ("STRING", {"forceInput": True}),
|
"coordinates": ("STRING", {"forceInput": True}),
|
||||||
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
|
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
|
||||||
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
|
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
|
||||||
"num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
|
"num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1073,13 +1096,12 @@ class ToraEncodeTrajectory:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, width, height, num_frames, coordinates, strength):
|
def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model):
|
||||||
check_diffusers_version()
|
check_diffusers_version()
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
|
||||||
traj_extractor = pipeline["pipe"].traj_extractor
|
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
vae.enable_slicing()
|
vae.enable_slicing()
|
||||||
vae._clear_fake_context_parallel_cache()
|
vae._clear_fake_context_parallel_cache()
|
||||||
@ -1108,22 +1130,33 @@ class ToraEncodeTrajectory:
|
|||||||
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
||||||
vae.to(offload_device)
|
vae.to(offload_device)
|
||||||
|
|
||||||
video_flow_features = traj_extractor(video_flow.to(torch.float32))
|
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
|
||||||
video_flow_features = torch.stack(video_flow_features)
|
video_flow_features = torch.stack(video_flow_features)
|
||||||
|
|
||||||
video_flow_features = video_flow_features * strength
|
video_flow_features = video_flow_features * strength
|
||||||
|
|
||||||
logging.info(f"video_flow shape: {video_flow.shape}")
|
logging.info(f"video_flow shape: {video_flow.shape}")
|
||||||
|
|
||||||
return (video_flow_features, video_flow_image.cpu().float())
|
tora = {
|
||||||
|
"video_flow_features" : video_flow_features,
|
||||||
|
"start_percent" : start_percent,
|
||||||
|
"end_percent" : end_percent,
|
||||||
|
"traj_extractor" : tora_model["traj_extractor"],
|
||||||
|
"fuser_list" : tora_model["fuser_list"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return (tora, video_flow_image.cpu().float())
|
||||||
|
|
||||||
class ToraEncodeOpticalFlow:
|
class ToraEncodeOpticalFlow:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"pipeline": ("COGVIDEOPIPE",),
|
"pipeline": ("COGVIDEOPIPE",),
|
||||||
|
"tora_model": ("TORAMODEL",),
|
||||||
"optical_flow": ("IMAGE", ),
|
"optical_flow": ("IMAGE", ),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
},
|
},
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -1133,14 +1166,13 @@ class ToraEncodeOpticalFlow:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, optical_flow, strength):
|
def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent):
|
||||||
check_diffusers_version()
|
check_diffusers_version()
|
||||||
B, H, W, C = optical_flow.shape
|
B, H, W, C = optical_flow.shape
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
|
||||||
traj_extractor = pipeline["pipe"].traj_extractor
|
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
vae.enable_slicing()
|
vae.enable_slicing()
|
||||||
vae._clear_fake_context_parallel_cache()
|
vae._clear_fake_context_parallel_cache()
|
||||||
@ -1157,14 +1189,22 @@ class ToraEncodeOpticalFlow:
|
|||||||
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
||||||
vae.to(offload_device)
|
vae.to(offload_device)
|
||||||
|
|
||||||
video_flow_features = traj_extractor(video_flow.to(torch.float32))
|
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
|
||||||
video_flow_features = torch.stack(video_flow_features)
|
video_flow_features = torch.stack(video_flow_features)
|
||||||
|
|
||||||
video_flow_features = video_flow_features * strength
|
video_flow_features = video_flow_features * strength
|
||||||
|
|
||||||
logging.info(f"video_flow shape: {video_flow.shape}")
|
logging.info(f"video_flow shape: {video_flow.shape}")
|
||||||
|
|
||||||
return (video_flow_features, )
|
tora = {
|
||||||
|
"video_flow_features" : video_flow_features,
|
||||||
|
"start_percent" : start_percent,
|
||||||
|
"end_percent" : end_percent,
|
||||||
|
"traj_extractor" : tora_model["traj_extractor"],
|
||||||
|
"fuser_list" : tora_model["fuser_list"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return (tora, )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1227,6 +1267,9 @@ class CogVideoSampler:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown scheduler: {scheduler}")
|
raise ValueError(f"Unknown scheduler: {scheduler}")
|
||||||
|
|
||||||
|
if tora_trajectory is not None:
|
||||||
|
pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
|
||||||
|
|
||||||
if context_options is not None:
|
if context_options is not None:
|
||||||
context_frames = context_options["context_frames"] // 4
|
context_frames = context_options["context_frames"] // 4
|
||||||
context_stride = context_options["context_stride"] // 4
|
context_stride = context_options["context_stride"] // 4
|
||||||
@ -1262,7 +1305,7 @@ class CogVideoSampler:
|
|||||||
context_overlap= context_overlap,
|
context_overlap= context_overlap,
|
||||||
freenoise=context_options["freenoise"] if context_options is not None else None,
|
freenoise=context_options["freenoise"] if context_options is not None else None,
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
video_flow_features=tora_trajectory if tora_trajectory is not None else None,
|
tora=tora_trajectory if tora_trajectory is not None else None,
|
||||||
)
|
)
|
||||||
if not pipeline["cpu_offloading"]:
|
if not pipeline["cpu_offloading"]:
|
||||||
pipe.transformer.to(offload_device)
|
pipe.transformer.to(offload_device)
|
||||||
@ -1809,6 +1852,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
||||||
"ToraEncodeTrajectory": ToraEncodeTrajectory,
|
"ToraEncodeTrajectory": ToraEncodeTrajectory,
|
||||||
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
|
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
|
||||||
|
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||||
@ -1831,4 +1875,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
||||||
"ToraEncodeTrajectory": "Tora Encode Trajectory",
|
"ToraEncodeTrajectory": "Tora Encode Trajectory",
|
||||||
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
|
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
|
||||||
|
"DownloadAndLoadToraModel": "(Down)load Tora Model",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -161,7 +161,6 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
self.original_mask = original_mask
|
self.original_mask = original_mask
|
||||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||||
|
|
||||||
self.traj_extractor = None
|
|
||||||
|
|
||||||
if pab_config is not None:
|
if pab_config is not None:
|
||||||
set_pab_manager(pab_config)
|
set_pab_manager(pab_config)
|
||||||
@ -390,7 +389,7 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
context_overlap: Optional[int] = None,
|
context_overlap: Optional[int] = None,
|
||||||
freenoise: Optional[bool] = True,
|
freenoise: Optional[bool] = True,
|
||||||
controlnet: Optional[dict] = None,
|
controlnet: Optional[dict] = None,
|
||||||
video_flow_features: Optional[torch.Tensor] = None,
|
tora: Optional[dict] = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -582,8 +581,8 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
if self.transformer.config.use_rotary_positional_embeddings
|
if self.transformer.config.use_rotary_positional_embeddings
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
if video_flow_features is not None and do_classifier_free_guidance:
|
if tora is not None and do_classifier_free_guidance:
|
||||||
video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
|
tora["video_flow_features"] = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
|
||||||
|
|
||||||
# 9. Controlnet
|
# 9. Controlnet
|
||||||
if controlnet is not None:
|
if controlnet is not None:
|
||||||
@ -784,11 +783,11 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
else:
|
else:
|
||||||
for c in context_queue:
|
for c in context_queue:
|
||||||
partial_latent_model_input = latent_model_input[:, c, :, :, :]
|
partial_latent_model_input = latent_model_input[:, c, :, :, :]
|
||||||
if video_flow_features is not None:
|
if tora is not None:
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
|
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
|
||||||
else:
|
else:
|
||||||
partial_video_flow_features = video_flow_features[:, c, :, :, :]
|
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :]
|
||||||
else:
|
else:
|
||||||
partial_video_flow_features = None
|
partial_video_flow_features = None
|
||||||
|
|
||||||
@ -869,7 +868,7 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
return_dict=False,
|
return_dict=False,
|
||||||
controlnet_states=controlnet_states,
|
controlnet_states=controlnet_states,
|
||||||
controlnet_weights=control_weights,
|
controlnet_weights=control_weights,
|
||||||
video_flow_features=video_flow_features,
|
video_flow_features=tora["video_flow_features"] if (tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
|
||||||
)[0]
|
)[0]
|
||||||
noise_pred = noise_pred.float()
|
noise_pred = noise_pred.float()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user