From adfc560ad62cef666fde7d6df5e0a720a8868f3f Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 21 Oct 2024 18:19:15 +0300
Subject: [PATCH] Add ToraEncodeOpticalFlow node

For using custom optical flows as input for the Tora model
---
 nodes.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 61 insertions(+), 4 deletions(-)

diff --git a/nodes.py b/nodes.py
index 9c35389..af71957 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1064,15 +1064,16 @@ class ToraEncodeTrajectory:
             "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
             "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
             "num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
+            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
             },
         }
 
-    RETURN_TYPES = ("TORAFEATURES", )
-    RETURN_NAMES = ("tora_trajectory", )
+    RETURN_TYPES = ("TORAFEATURES", "IMAGE", )
+    RETURN_NAMES = ("tora_trajectory", "video_flow_images", )
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, width, height, num_frames, coordinates):
+    def encode(self, pipeline, width, height, num_frames, coordinates, strength):
         check_diffusers_version()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -1091,10 +1092,13 @@ class ToraEncodeTrajectory:
         video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
         video_flow = rearrange(video_flow, "T H W C -> T C H W")
         video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
+
         video_flow = (
             rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
         )
+        video_flow_image = rearrange(video_flow, "B C T H W -> (B T) H W C")
+        print(video_flow_image.shape)
 
         mm.soft_empty_cache()
 
         # VAE encode
@@ -1107,9 +1111,60 @@ class ToraEncodeTrajectory:
         video_flow_features = traj_extractor(video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
+        video_flow_features = video_flow_features * strength
+
         logging.info(f"video_flow shape: {video_flow.shape}")
 
-        return (video_flow_features,)
+        return (video_flow_features, video_flow_image.cpu().float())
+
+class ToraEncodeOpticalFlow:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "pipeline": ("COGVIDEOPIPE",),
+            "optical_flow": ("IMAGE", ),
+            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+            },
+        }
+
+    RETURN_TYPES = ("TORAFEATURES",)
+    RETURN_NAMES = ("tora_trajectory",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, pipeline, optical_flow, strength):
+        check_diffusers_version()
+        B, H, W, C = optical_flow.shape
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        generator = torch.Generator(device=device).manual_seed(0)
+
+        traj_extractor = pipeline["pipe"].traj_extractor
+        vae = pipeline["pipe"].vae
+        vae.enable_slicing()
+        vae._clear_fake_context_parallel_cache()
+
+        video_flow = optical_flow * 2 - 1
+        video_flow = rearrange(video_flow, "(B T) H W C -> B C T H W", T=B, B=1)
+        print(video_flow.shape)
+        mm.soft_empty_cache()
+
+        # VAE encode
+        if not pipeline["cpu_offloading"]:
+            vae.to(device)
+        video_flow = video_flow.to(vae.dtype).to(vae.device)
+        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
+        vae.to(offload_device)
+
+        video_flow_features = traj_extractor(video_flow.to(torch.float32))
+        video_flow_features = torch.stack(video_flow_features)
+
+        video_flow_features = video_flow_features * strength
+
+        logging.info(f"video_flow shape: {video_flow.shape}")
+
+        return (video_flow_features, )
@@ -1753,6 +1808,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoControlNet": CogVideoControlNet,
     "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
     "ToraEncodeTrajectory": ToraEncodeTrajectory,
+    "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1774,4 +1830,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoContextOptions": "CogVideo Context Options",
     "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
     "ToraEncodeTrajectory": "Tora Encode Trajectory",
+    "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
 }
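
Note (not part of the patch): ToraEncodeOpticalFlow expects its "optical_flow"
IMAGE input to already contain optical-flow visualization images in [0, 1],
stacked along the batch axis as (B T) H W C; encode() only rescales them to
[-1, 1] and treats the batch axis as time (the rearrange with T=B, B=1).
Below is a minimal sketch of how such an input could be prepared, assuming
torchvision's RAFT model and its flow_to_image helper (the ToraEncodeTrajectory
path also calls a flow_to_image; whether it is the torchvision one is an
assumption here). The frame tensor and variable names are illustrative only:

    import torch
    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
    from torchvision.utils import flow_to_image

    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights).eval()
    transforms = weights.transforms()

    # Illustrative input: uint8 video frames [T, 3, H, W]; RAFT needs
    # H and W divisible by 8. In practice, run pairs in chunks to save VRAM.
    frames = torch.randint(0, 256, (49, 3, 480, 720), dtype=torch.uint8)

    with torch.no_grad():
        img1, img2 = transforms(frames[:-1], frames[1:])
        flow = model(img1, img2)[-1]  # final RAFT iteration: [T-1, 2, H, W]

    # Flow -> RGB visualization, then to the (B T) H W C float layout in
    # [0, 1] that the node rescales to [-1, 1] before VAE encoding.
    flow_images = flow_to_image(flow).permute(0, 2, 3, 1).float() / 255.0

Accepting ready-made flow images rather than raw 2-channel flow means any flow
estimator, or even a hand-drawn flow visualization, can drive the Tora model.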