mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 22:14:23 +08:00
Add ToraEncodeOpticalFlow -node
For using custom optical flows as input for the Tora model
This commit is contained in:
parent
4c3fcd7b01
commit
adfc560ad6
65
nodes.py
65
nodes.py
@ -1064,15 +1064,16 @@ class ToraEncodeTrajectory:
|
||||
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
|
||||
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
|
||||
"num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TORAFEATURES", )
|
||||
RETURN_NAMES = ("tora_trajectory", )
|
||||
RETURN_TYPES = ("TORAFEATURES", "IMAGE", )
|
||||
RETURN_NAMES = ("tora_trajectory", "video_flow_images", )
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def encode(self, pipeline, width, height, num_frames, coordinates):
|
||||
def encode(self, pipeline, width, height, num_frames, coordinates, strength):
|
||||
check_diffusers_version()
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -1091,10 +1092,13 @@ class ToraEncodeTrajectory:
|
||||
video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
|
||||
video_flow = rearrange(video_flow, "T H W C -> T C H W")
|
||||
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
|
||||
|
||||
|
||||
video_flow = (
|
||||
rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
|
||||
)
|
||||
video_flow_image = rearrange(video_flow, "B C T H W -> (B T) H W C")
|
||||
print(video_flow_image.shape)
|
||||
mm.soft_empty_cache()
|
||||
|
||||
# VAE encode
|
||||
@ -1107,9 +1111,60 @@ class ToraEncodeTrajectory:
|
||||
video_flow_features = traj_extractor(video_flow.to(torch.float32))
|
||||
video_flow_features = torch.stack(video_flow_features)
|
||||
|
||||
video_flow_features = video_flow_features * strength
|
||||
|
||||
logging.info(f"video_flow shape: {video_flow.shape}")
|
||||
|
||||
return (video_flow_features,)
|
||||
return (video_flow_features, video_flow_image.cpu().float())
|
||||
|
||||
class ToraEncodeOpticalFlow:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"pipeline": ("COGVIDEOPIPE",),
|
||||
"optical_flow": ("IMAGE", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TORAFEATURES",)
|
||||
RETURN_NAMES = ("tora_trajectory",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def encode(self, pipeline, optical_flow, strength):
|
||||
check_diffusers_version()
|
||||
B, H, W, C = optical_flow.shape
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
traj_extractor = pipeline["pipe"].traj_extractor
|
||||
vae = pipeline["pipe"].vae
|
||||
vae.enable_slicing()
|
||||
vae._clear_fake_context_parallel_cache()
|
||||
|
||||
video_flow = optical_flow * 2 - 1
|
||||
video_flow = rearrange(video_flow, "(B T) H W C -> B C T H W", T=B, B=1)
|
||||
print(video_flow.shape)
|
||||
mm.soft_empty_cache()
|
||||
|
||||
# VAE encode
|
||||
if not pipeline["cpu_offloading"]:
|
||||
vae.to(device)
|
||||
video_flow = video_flow.to(vae.dtype).to(vae.device)
|
||||
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
||||
vae.to(offload_device)
|
||||
|
||||
video_flow_features = traj_extractor(video_flow.to(torch.float32))
|
||||
video_flow_features = torch.stack(video_flow_features)
|
||||
|
||||
video_flow_features = video_flow_features * strength
|
||||
|
||||
logging.info(f"video_flow shape: {video_flow.shape}")
|
||||
|
||||
return (video_flow_features, )
|
||||
|
||||
|
||||
|
||||
@ -1753,6 +1808,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoControlNet": CogVideoControlNet,
|
||||
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
||||
"ToraEncodeTrajectory": ToraEncodeTrajectory,
|
||||
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||
@ -1774,4 +1830,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoContextOptions": "CogVideo Context Options",
|
||||
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
||||
"ToraEncodeTrajectory": "Tora Encode Trajectory",
|
||||
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user