diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 730fef9..daac903 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -23,7 +23,7 @@ import numpy as np from einops import rearrange from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import is_torch_version, logging +from diffusers.utils import logging from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.attention import Attention, FeedForward from diffusers.models.attention_processor import AttentionProcessor @@ -37,11 +37,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name try: from sageattention import sageattn - SAGEATTN_IS_AVAVILABLE = True + SAGEATTN_IS_AVAILABLE = True logger.info("Using sageattn") except: logger.info("sageattn not found, using sdpa") - SAGEATTN_IS_AVAVILABLE = False + SAGEATTN_IS_AVAILABLE = False class CogVideoXAttnProcessor2_0: r""" @@ -97,7 +97,7 @@ class CogVideoXAttnProcessor2_0: if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - if SAGEATTN_IS_AVAVILABLE: + if SAGEATTN_IS_AVAILABLE: hidden_states = sageattn(query, key, value, is_causal=False) else: hidden_states = F.scaled_dot_product_attention( @@ -171,7 +171,7 @@ class FusedCogVideoXAttnProcessor2_0: if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - if SAGEATTN_IS_AVAVILABLE: + if SAGEATTN_IS_AVAILABLE: hidden_states = sageattn(query, key, value, is_causal=False) else: hidden_states = F.scaled_dot_product_attention( diff --git a/examples/cogvideox_5b_Tora_I2V_testing_01.json b/examples/cogvideox_5b_Tora_I2V_testing_01.json new file mode 100644 index 0000000..39f1ed1 --- /dev/null +++ b/examples/cogvideox_5b_Tora_I2V_testing_01.json @@ -0,0 +1,1324 @@ +{ + "last_node_id": 75, + "last_link_id": 176, + "nodes": [ + { + "id": 20, + "type": "CLIPLoader", + "pos": { + "0": -26, + "1": 400 + }, + "size": { + "0": 451.30548095703125, + "1": 82 + }, + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 54, + 56 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CLIPLoader" + }, + "widgets_values": [ + "t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors", + "sd3" + ] + }, + { + "id": 31, + "type": "CogVideoTextEncode", + "pos": { + "0": 497, + "1": 520 + }, + "size": { + "0": 463.01251220703125, + "1": 124 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 56 + } + ], + "outputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "links": [ + 123 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoTextEncode" + }, + "widgets_values": [ + "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ", + 1, + true + ] + }, + { + "id": 65, + "type": "CreateShapeImageOnPath", + "pos": { + "0": 1052, + "1": 935 + }, + "size": { + "0": 313.4619445800781, + "1": 270 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "coordinates", + "type": "STRING", + "link": 145, + "widget": { + "name": "coordinates" + } + }, + { + "name": "size_multiplier", + "type": "FLOAT", + "link": null, + "widget": { + "name": "size_multiplier" + }, + "shape": 7 + }, + { + "name": "frame_width", + "type": "INT", + "link": 149, + "widget": { + "name": "frame_width" + } + }, + { + "name": "frame_height", + "type": "INT", + "link": 150, + "widget": { + "name": "frame_height" + } + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 142, + 153 + ], + "slot_index": 0 + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 154 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "CreateShapeImageOnPath" + }, + "widgets_values": [ + "circle", + "", + 512, + 512, + 12, + 12, + "red", + "black", + 0, + 1, + [ + 1 + ] + ] + }, + { + "id": 66, + "type": "VHS_VideoCombine", + "pos": { + "0": 1405, + "1": 916 + }, + "size": [ + 605.3909912109375, + 714.2606608072917 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 142 + }, + { + "name": "audio", + "type": "AUDIO", + "link": null, + "shape": 7 + }, + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null, + "shape": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "Filenames", + "type": "VHS_FILENAMES", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoCombine" + }, + "widgets_values": { + "frame_rate": 8, + "loop_count": 0, + "filename_prefix": "CogVideoX-Tora-trajectory", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 19, + "save_metadata": true, + "pingpong": false, + "save_output": false, + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "filename": "CogVideoX-Tora-trajectory_00011.mp4", + "subfolder": "", + "type": "temp", + "format": "video/h264-mp4", + "frame_rate": 8 + }, + "muted": false + } + } + }, + { + "id": 56, + "type": "CogVideoDecode", + "pos": { + "0": 1596, + "1": 150 + }, + "size": { + "0": 300.396484375, + "1": 198 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 128 + }, + { + "name": "samples", + "type": "LATENT", + "link": 127 + } + ], + "outputs": [ + { + "name": "images", + "type": "IMAGE", + "links": [ + 155 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoDecode" + }, + "widgets_values": [ + false, + 240, + 360, + 0.2, + 0.2, + true + ] + }, + { + "id": 57, + "type": "CogVideoSampler", + "pos": { + "0": 1138, + "1": 150 + }, + "size": { + "0": 399.8780822753906, + "1": 390 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 121 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 122 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 123 + }, + { + "name": "samples", + "type": "LATENT", + "link": null, + "shape": 7 + }, + { + "name": "image_cond_latents", + "type": "LATENT", + "link": 162, + "shape": 7 + }, + { + "name": "context_options", + "type": "COGCONTEXT", + "link": null, + "shape": 7 + }, + { + "name": "controlnet", + "type": "COGVIDECONTROLNET", + "link": null, + "shape": 7 + }, + { + "name": "tora_trajectory", + "type": "TORAFEATURES", + "link": 173, + "shape": 7 + }, + { + "name": "num_frames", + "type": "INT", + "link": 157, + "widget": { + "name": "num_frames" + } + }, + { + "name": "height", + "type": "INT", + "link": 151, + "widget": { + "name": "height" + } + }, + { + "name": "width", + "type": "INT", + "link": 152, + "widget": { + "name": "width" + } + } + ], + "outputs": [ + { + "name": "cogvideo_pipe", + "type": "COGVIDEOPIPE", + "links": [ + 128 + ], + "slot_index": 0, + "shape": 3 + }, + { + "name": "samples", + "type": "LATENT", + "links": [ + 127 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoSampler" + }, + "widgets_values": [ + 480, + 720, + 49, + 32, + 6, + 65334758276105, + "fixed", + "CogVideoXDPMScheduler", + 1 + ] + }, + { + "id": 71, + "type": "CogVideoImageEncode", + "pos": { + "0": 68.59265899658203, + "1": 573.0311889648438 + }, + "size": { + "0": 315, + "1": 122 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 164 + }, + { + "name": "image", + "type": "IMAGE", + "link": 167 + }, + { + "name": "mask", + "type": "MASK", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "samples", + "type": "LATENT", + "links": [ + 162 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CogVideoImageEncode" + }, + "widgets_values": [ + 16, + false + ] + }, + { + "id": 67, + "type": "GetMaskSizeAndCount", + "pos": { + "0": 763, + "1": 772 + }, + "size": { + "0": 264.5999755859375, + "1": 86 + }, + "flags": { + "collapsed": true + }, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 146 + } + ], + "outputs": [ + { + "name": "mask", + "type": "MASK", + "links": null + }, + { + "name": "width", + "type": "INT", + "links": [ + 149, + 152, + 168, + 171 + ], + "slot_index": 1 + }, + { + "name": "height", + "type": "INT", + "links": [ + 150, + 151, + 169, + 172 + ], + "slot_index": 2 + }, + { + "name": "count", + "type": "INT", + "links": [ + 157, + 170 + ], + "slot_index": 3 + } + ], + "properties": { + "Node name for S&R": "GetMaskSizeAndCount" + }, + "widgets_values": [] + }, + { + "id": 72, + "type": "LoadImage", + "pos": { + "0": -820, + "1": 531 + }, + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 166 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "sd3stag.png", + "image" + ] + }, + { + "id": 73, + "type": "ImageResizeKJ", + "pos": { + "0": -436, + "1": 527 + }, + "size": { + "0": 315, + "1": 266 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 166 + }, + { + "name": "get_image_size", + "type": "IMAGE", + "link": null, + "shape": 7 + }, + { + "name": "width_input", + "type": "INT", + "link": 168, + "widget": { + "name": "width_input" + }, + "shape": 7 + }, + { + "name": "height_input", + "type": "INT", + "link": 169, + "widget": { + "name": "height_input" + }, + "shape": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 167 + ], + "slot_index": 0 + }, + { + "name": "width", + "type": "INT", + "links": null + }, + { + "name": "height", + "type": "INT", + "links": null + } + ], + "properties": { + "Node name for S&R": "ImageResizeKJ" + }, + "widgets_values": [ + 512, + 512, + "nearest-exact", + false, + 2, + 0, + 0, + "disabled" + ] + }, + { + "id": 68, + "type": "ImageCompositeMasked", + "pos": { + "0": 1674, + "1": 641 + }, + "size": { + "0": 315, + "1": 146 + }, + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "destination", + "type": "IMAGE", + "link": 155 + }, + { + "name": "source", + "type": "IMAGE", + "link": 153 + }, + { + "name": "mask", + "type": "MASK", + "link": 154, + "shape": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 156 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ImageCompositeMasked" + }, + "widgets_values": [ + 0, + 0, + false + ] + }, + { + "id": 60, + "type": "SplineEditor", + "pos": { + "0": -103, + "1": 770 + }, + "size": { + "0": 765, + "1": 910 + }, + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "mask", + "type": "MASK", + "links": [ + 146 + ], + "slot_index": 0 + }, + { + "name": "coord_str", + "type": "STRING", + "links": [ + 145, + 176 + ], + "slot_index": 1 + }, + { + "name": "float", + "type": "FLOAT", + "links": null + }, + { + "name": "count", + "type": "INT", + "links": null + }, + { + "name": "normalized_str", + "type": "STRING", + "links": null + } + ], + "properties": { + "Node name for S&R": "SplineEditor", + "points": "SplineEditor" + }, + "widgets_values": [ + "[{\"x\":366.43744764656,\"y\":171.3214040944956},{\"x\":466.3749333683491,\"y\":177.6666412831806},{\"x\":539.3451610382268,\"y\":195.1160435520644},{\"x\":276.01781770779843,\"y\":199.87497144357818}]", + "[{\"x\":366.43743896484375,\"y\":171.3214111328125},{\"x\":373.86798095703125,\"y\":171.79318237304688},{\"x\":381.29852294921875,\"y\":172.26495361328125},{\"x\":388.7288818359375,\"y\":172.73956298828125},{\"x\":396.1580810546875,\"y\":173.23184204101562},{\"x\":403.58544921875,\"y\":173.7510223388672},{\"x\":411.0102233886719,\"y\":174.30575561523438},{\"x\":418.4319763183594,\"y\":174.8998260498047},{\"x\":425.85003662109375,\"y\":175.53823852539062},{\"x\":433.26348876953125,\"y\":176.2280731201172},{\"x\":440.67156982421875,\"y\":176.9736328125},{\"x\":448.0726623535156,\"y\":177.78512573242188},{\"x\":455.4649658203125,\"y\":178.67330932617188},{\"x\":462.8458557128906,\"y\":179.65150451660156},{\"x\":470.2113952636719,\"y\":180.73902893066406},{\"x\":477.5547180175781,\"y\":181.96739196777344},{\"x\":484.8601379394531,\"y\":183.40267944335938},{\"x\":492.0770568847656,\"y\":185.22531127929688},{\"x\":498.24371337890625,\"y\":188.81117248535156},{\"x\":491.68231201171875,\"y\":191.73179626464844},{\"x\":484.3272705078125,\"y\":192.8770294189453},{\"x\":476.9224853515625,\"y\":193.65155029296875},{\"x\":469.50146484375,\"y\":194.25323486328125},{\"x\":462.07281494140625,\"y\":194.7535400390625},{\"x\":454.6398620605469,\"y\":195.1853790283203},{\"x\":447.2041931152344,\"y\":195.56698608398438},{\"x\":439.7665710449219,\"y\":195.90963745117188},{\"x\":432.32757568359375,\"y\":196.2206573486328},{\"x\":424.8875427246094,\"y\":196.50531005859375},{\"x\":417.4466552734375,\"y\":196.76824951171875},{\"x\":410.0051574707031,\"y\":197.01141357421875},{\"x\":402.5631103515625,\"y\":197.23898315429688},{\"x\":395.1206970214844,\"y\":197.45263671875},{\"x\":387.6778869628906,\"y\":197.6529541015625},{\"x\":380.23480224609375,\"y\":197.8413848876953},{\"x\":372.7914123535156,\"y\":198.0200653076172},{\"x\":365.3478698730469,\"y\":198.19000244140625},{\"x\":357.90411376953125,\"y\":198.350341796875},{\"x\":350.4601745605469,\"y\":198.50411987304688},{\"x\":343.01611328125,\"y\":198.65133666992188},{\"x\":335.5719909667969,\"y\":198.79347229003906},{\"x\":328.12774658203125,\"y\":198.93048095703125},{\"x\":320.68353271484375,\"y\":199.0675048828125},{\"x\":313.2392578125,\"y\":199.20228576660156},{\"x\":305.79498291015625,\"y\":199.33682250976562},{\"x\":298.35064697265625,\"y\":199.4713592529297},{\"x\":290.9063720703125,\"y\":199.60589599609375},{\"x\":283.46209716796875,\"y\":199.7404327392578},{\"x\":276.017822265625,\"y\":199.87496948242188}]", + 720, + 480, + 49, + "path", + "basis", + 0.5, + 1, + "list", + 0, + 1, + null, + null + ] + }, + { + "id": 30, + "type": "CogVideoTextEncode", + "pos": { + "0": 493, + "1": 303 + }, + "size": { + "0": 471.90142822265625, + "1": 168.08047485351562 + }, + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 54 + } + ], + "outputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "links": [ + 122 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoTextEncode" + }, + "widgets_values": [ + "a stag is standing and looking around in a forest", + 1, + true + ] + }, + { + "id": 44, + "type": "VHS_VideoCombine", + "pos": { + "0": 2210, + "1": 151 + }, + "size": [ + 1131.619140625, + 1065.0794270833335 + ], + "flags": {}, + "order": 16, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 156 + }, + { + "name": "audio", + "type": "AUDIO", + "link": null, + "shape": 7 + }, + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null, + "shape": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "Filenames", + "type": "VHS_FILENAMES", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoCombine" + }, + "widgets_values": { + "frame_rate": 16, + "loop_count": 0, + "filename_prefix": "CogVideoX-Tora", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 19, + "save_metadata": true, + "pingpong": false, + "save_output": false, + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "filename": "CogVideoX-Tora_00012.mp4", + "subfolder": "", + "type": "temp", + "format": "video/h264-mp4", + "frame_rate": 16 + }, + "muted": false + } + } + }, + { + "id": 75, + "type": "DownloadAndLoadToraModel", + "pos": { + "0": 253, + "1": 146 + }, + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "tora_model", + "type": "TORAMODEL", + "links": [ + 175 + ] + } + ], + "properties": { + "Node name for S&R": "DownloadAndLoadToraModel" + }, + "widgets_values": [ + "kijai/CogVideoX-5b-Tora" + ] + }, + { + "id": 74, + "type": "ToraEncodeTrajectory", + "pos": { + "0": 1060, + "1": 670 + }, + "size": [ + 335.1993359916705, + 206 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 174 + }, + { + "name": "tora_model", + "type": "TORAMODEL", + "link": 175 + }, + { + "name": "coordinates", + "type": "STRING", + "link": 176, + "widget": { + "name": "coordinates" + } + }, + { + "name": "num_frames", + "type": "INT", + "link": 170, + "widget": { + "name": "num_frames" + } + }, + { + "name": "width", + "type": "INT", + "link": 171, + "widget": { + "name": "width" + } + }, + { + "name": "height", + "type": "INT", + "link": 172, + "widget": { + "name": "height" + } + } + ], + "outputs": [ + { + "name": "tora_trajectory", + "type": "TORAFEATURES", + "links": [ + 173 + ] + }, + { + "name": "video_flow_images", + "type": "IMAGE", + "links": null + } + ], + "properties": { + "Node name for S&R": "ToraEncodeTrajectory" + }, + "widgets_values": [ + "", + 720, + 480, + 49, + 1, + 0, + 0.1 + ] + }, + { + "id": 1, + "type": "DownloadAndLoadCogVideoModel", + "pos": { + "0": 633, + "1": 44 + }, + "size": { + "0": 337.8885192871094, + "1": 194 + }, + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "pab_config", + "type": "PAB_CONFIG", + "link": null, + "shape": 7 + }, + { + "name": "block_edit", + "type": "TRANSFORMERBLOCKS", + "link": null, + "shape": 7 + }, + { + "name": "lora", + "type": "COGLORA", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "cogvideo_pipe", + "type": "COGVIDEOPIPE", + "links": [ + 121, + 164, + 174 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "DownloadAndLoadCogVideoModel" + }, + "widgets_values": [ + "THUDM/CogVideoX-5b-I2V", + "bf16", + "disabled", + "disabled", + false + ] + } + ], + "links": [ + [ + 54, + 20, + 0, + 30, + 0, + "CLIP" + ], + [ + 56, + 20, + 0, + 31, + 0, + "CLIP" + ], + [ + 121, + 1, + 0, + 57, + 0, + "COGVIDEOPIPE" + ], + [ + 122, + 30, + 0, + 57, + 1, + "CONDITIONING" + ], + [ + 123, + 31, + 0, + 57, + 2, + "CONDITIONING" + ], + [ + 127, + 57, + 1, + 56, + 1, + "LATENT" + ], + [ + 128, + 57, + 0, + 56, + 0, + "COGVIDEOPIPE" + ], + [ + 142, + 65, + 0, + 66, + 0, + "IMAGE" + ], + [ + 145, + 60, + 1, + 65, + 0, + "STRING" + ], + [ + 146, + 60, + 0, + 67, + 0, + "MASK" + ], + [ + 149, + 67, + 1, + 65, + 2, + "INT" + ], + [ + 150, + 67, + 2, + 65, + 3, + "INT" + ], + [ + 151, + 67, + 2, + 57, + 9, + "INT" + ], + [ + 152, + 67, + 1, + 57, + 10, + "INT" + ], + [ + 153, + 65, + 0, + 68, + 1, + "IMAGE" + ], + [ + 154, + 65, + 1, + 68, + 2, + "MASK" + ], + [ + 155, + 56, + 0, + 68, + 0, + "IMAGE" + ], + [ + 156, + 68, + 0, + 44, + 0, + "IMAGE" + ], + [ + 157, + 67, + 3, + 57, + 8, + "INT" + ], + [ + 162, + 71, + 0, + 57, + 4, + "LATENT" + ], + [ + 164, + 1, + 0, + 71, + 0, + "COGVIDEOPIPE" + ], + [ + 166, + 72, + 0, + 73, + 0, + "IMAGE" + ], + [ + 167, + 73, + 0, + 71, + 1, + "IMAGE" + ], + [ + 168, + 67, + 1, + 73, + 2, + "INT" + ], + [ + 169, + 67, + 2, + 73, + 3, + "INT" + ], + [ + 170, + 67, + 3, + 74, + 3, + "INT" + ], + [ + 171, + 67, + 1, + 74, + 4, + "INT" + ], + [ + 172, + 67, + 2, + 74, + 5, + "INT" + ], + [ + 173, + 74, + 0, + 57, + 7, + "TORAFEATURES" + ], + [ + 174, + 1, + 0, + 74, + 0, + "COGVIDEOPIPE" + ], + [ + 175, + 75, + 0, + 74, + 1, + "TORAMODEL" + ], + [ + 176, + 60, + 1, + 74, + 2, + "STRING" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.5730855330117661, + "offset": [ + 931.6955110788911, + 94.46846235728464 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index af71957..4a044db 100644 --- a/nodes.py +++ b/nodes.py @@ -1,5 +1,6 @@ import os import torch +import torch.nn as nn import folder_paths import comfy.model_management as mm from comfy.utils import ProgressBar, load_torch_file @@ -420,38 +421,6 @@ class DownloadAndLoadCogVideoModel: fuse_qkv_projections=True if pab_config is None else False, ) - if "Tora" in model: - import torch.nn as nn - from .tora.traj_module import MGF - - hidden_size = 3072 - num_layers = transformer.num_layers - pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)]) - fuser_sd = load_torch_file(os.path.join(base_path, "fuser", "fuser.safetensors")) - pipe.transformer.fuser_list.load_state_dict(fuser_sd) - for module in transformer.fuser_list: - for param in module.parameters(): - param.data = param.data.to(torch.float16) - del fuser_sd - - from .tora.traj_module import TrajExtractor - traj_extractor = TrajExtractor( - vae_downsize=(4, 8, 8), - patch_size=2, - nums_rb=2, - cin=vae.config.latent_channels, - channels=[128] * transformer.num_layers, - sk=True, - use_conv=False, - ) - - traj_sd = load_torch_file(os.path.join(base_path, "traj_extractor", "traj_extractor.safetensors")) - traj_extractor.load_state_dict(traj_sd) - traj_extractor.to(torch.float32).to(device) - - pipe.traj_extractor = traj_extractor - - pipeline = { "pipe": pipe, "dtype": dtype, @@ -622,63 +591,6 @@ class DownloadAndLoadCogVideoGGUFModel: vae.load_state_dict(vae_sd) pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) - if "Tora" in model: - import torch.nn as nn - from .tora.traj_module import MGF - - download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora") - fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors") - if not os.path.exists(fuser_path): - log.info(f"Downloading Fuser model to: {fuser_path}") - from huggingface_hub import snapshot_download - - snapshot_download( - repo_id="kijai/CogVideoX-5b-Tora", - allow_patterns=["*fuser.safetensors*"], - local_dir=download_path, - local_dir_use_symlinks=False, - ) - - hidden_size = 3072 - num_layers = transformer.num_layers - pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)]) - - fuser_sd = load_torch_file(fuser_path) - pipe.transformer.fuser_list.load_state_dict(fuser_sd) - for module in transformer.fuser_list: - for param in module.parameters(): - param.data = param.data.to(torch.float16) - del fuser_sd - - traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors") - if not os.path.exists(traj_extractor_path): - log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}") - from huggingface_hub import snapshot_download - - snapshot_download( - repo_id="kijai/CogVideoX-5b-Tora", - allow_patterns=["*traj_extractor.safetensors*"], - local_dir=download_path, - local_dir_use_symlinks=False, - ) - - from .tora.traj_module import TrajExtractor - traj_extractor = TrajExtractor( - vae_downsize=(4, 8, 8), - patch_size=2, - nums_rb=2, - cin=vae.config.latent_channels, - channels=[128] * transformer.num_layers, - sk=True, - use_conv=False, - ) - - traj_sd = load_torch_file(traj_extractor_path) - traj_extractor.load_state_dict(traj_sd) - traj_extractor.to(torch.float32).to(device) - - pipe.traj_extractor = traj_extractor - if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() @@ -694,6 +606,114 @@ class DownloadAndLoadCogVideoGGUFModel: return (pipeline,) +class DownloadAndLoadToraModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ( + [ + "kijai/CogVideoX-5b-Tora", + ], + ), + }, + } + + RETURN_TYPES = ("TORAMODEL",) + RETURN_NAMES = ("tora_model", ) + FUNCTION = "loadmodel" + CATEGORY = "CogVideoWrapper" + DESCRIPTION = "Downloads and loads the the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'" + + def loadmodel(self, model): + + check_diffusers_version() + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + download_path = folder_paths.get_folder_paths("CogVideo")[0] + + from .tora.traj_module import MGF + + try: + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + is_accelerate_available = True + except: + is_accelerate_available = False + pass + + download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora") + fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors") + if not os.path.exists(fuser_path): + log.info(f"Downloading Fuser model to: {fuser_path}") + from huggingface_hub import snapshot_download + + snapshot_download( + repo_id=model, + allow_patterns=["*fuser.safetensors*"], + local_dir=download_path, + local_dir_use_symlinks=False, + ) + + hidden_size = 3072 + num_layers = 42 + + with (init_empty_weights() if is_accelerate_available else nullcontext()): + fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)]) + + fuser_sd = load_torch_file(fuser_path) + if is_accelerate_available: + for key in fuser_sd: + set_module_tensor_to_device(fuser_list, key, dtype=torch.float16, device=device, value=fuser_sd[key]) + else: + fuser_list.load_state_dict(fuser_sd) + for module in fuser_list: + for param in module.parameters(): + param.data = param.data.to(torch.float16).to(device) + del fuser_sd + + traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors") + if not os.path.exists(traj_extractor_path): + log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}") + from huggingface_hub import snapshot_download + + snapshot_download( + repo_id="kijai/CogVideoX-5b-Tora", + allow_patterns=["*traj_extractor.safetensors*"], + local_dir=download_path, + local_dir_use_symlinks=False, + ) + + from .tora.traj_module import TrajExtractor + with (init_empty_weights() if is_accelerate_available else nullcontext()): + traj_extractor = TrajExtractor( + vae_downsize=(4, 8, 8), + patch_size=2, + nums_rb=2, + cin=16, + channels=[128] * 42, + sk=True, + use_conv=False, + ) + + traj_sd = load_torch_file(traj_extractor_path) + if is_accelerate_available: + for key in traj_sd: + set_module_tensor_to_device(traj_extractor, key, dtype=torch.float32, device=device, value=traj_sd[key]) + else: + traj_extractor.load_state_dict(traj_sd) + traj_extractor.to(torch.float32).to(device) + + toramodel = { + "fuser_list": fuser_list, + "traj_extractor": traj_extractor, + } + + return (toramodel,) + class DownloadAndLoadCogVideoControlNet: @classmethod def INPUT_TYPES(s): @@ -1060,11 +1080,14 @@ class ToraEncodeTrajectory: def INPUT_TYPES(s): return {"required": { "pipeline": ("COGVIDEOPIPE",), + "tora_model": ("TORAMODEL",), "coordinates": ("STRING", {"forceInput": True}), "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}), "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), "num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, } @@ -1073,13 +1096,12 @@ class ToraEncodeTrajectory: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, width, height, num_frames, coordinates, strength): + def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model): check_diffusers_version() device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) - traj_extractor = pipeline["pipe"].traj_extractor vae = pipeline["pipe"].vae vae.enable_slicing() vae._clear_fake_context_parallel_cache() @@ -1108,22 +1130,33 @@ class ToraEncodeTrajectory: video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor vae.to(offload_device) - video_flow_features = traj_extractor(video_flow.to(torch.float32)) + video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) video_flow_features = torch.stack(video_flow_features) video_flow_features = video_flow_features * strength logging.info(f"video_flow shape: {video_flow.shape}") - return (video_flow_features, video_flow_image.cpu().float()) + tora = { + "video_flow_features" : video_flow_features, + "start_percent" : start_percent, + "end_percent" : end_percent, + "traj_extractor" : tora_model["traj_extractor"], + "fuser_list" : tora_model["fuser_list"], + } + + return (tora, video_flow_image.cpu().float()) class ToraEncodeOpticalFlow: @classmethod def INPUT_TYPES(s): return {"required": { "pipeline": ("COGVIDEOPIPE",), + "tora_model": ("TORAMODEL",), "optical_flow": ("IMAGE", ), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, } @@ -1133,14 +1166,13 @@ class ToraEncodeOpticalFlow: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, optical_flow, strength): + def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent): check_diffusers_version() B, H, W, C = optical_flow.shape device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) - traj_extractor = pipeline["pipe"].traj_extractor vae = pipeline["pipe"].vae vae.enable_slicing() vae._clear_fake_context_parallel_cache() @@ -1157,14 +1189,22 @@ class ToraEncodeOpticalFlow: video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor vae.to(offload_device) - video_flow_features = traj_extractor(video_flow.to(torch.float32)) + video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) video_flow_features = torch.stack(video_flow_features) video_flow_features = video_flow_features * strength logging.info(f"video_flow shape: {video_flow.shape}") - return (video_flow_features, ) + tora = { + "video_flow_features" : video_flow_features, + "start_percent" : start_percent, + "end_percent" : end_percent, + "traj_extractor" : tora_model["traj_extractor"], + "fuser_list" : tora_model["fuser_list"], + } + + return (tora, ) @@ -1227,6 +1267,9 @@ class CogVideoSampler: else: raise ValueError(f"Unknown scheduler: {scheduler}") + if tora_trajectory is not None: + pipe.transformer.fuser_list = tora_trajectory["fuser_list"] + if context_options is not None: context_frames = context_options["context_frames"] // 4 context_stride = context_options["context_stride"] // 4 @@ -1262,7 +1305,7 @@ class CogVideoSampler: context_overlap= context_overlap, freenoise=context_options["freenoise"] if context_options is not None else None, controlnet=controlnet, - video_flow_features=tora_trajectory if tora_trajectory is not None else None, + tora=tora_trajectory if tora_trajectory is not None else None, ) if not pipeline["cpu_offloading"]: pipe.transformer.to(offload_device) @@ -1809,6 +1852,7 @@ NODE_CLASS_MAPPINGS = { "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet, "ToraEncodeTrajectory": ToraEncodeTrajectory, "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow, + "DownloadAndLoadToraModel": DownloadAndLoadToraModel, } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -1831,4 +1875,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet", "ToraEncodeTrajectory": "Tora Encode Trajectory", "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow", + "DownloadAndLoadToraModel": "(Down)load Tora Model", } diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 4310a1d..673f854 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -161,7 +161,6 @@ class CogVideoXPipeline(VideoSysPipeline): self.original_mask = original_mask self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.traj_extractor = None if pab_config is not None: set_pab_manager(pab_config) @@ -390,7 +389,7 @@ class CogVideoXPipeline(VideoSysPipeline): context_overlap: Optional[int] = None, freenoise: Optional[bool] = True, controlnet: Optional[dict] = None, - video_flow_features: Optional[torch.Tensor] = None, + tora: Optional[dict] = None, ): """ @@ -582,8 +581,8 @@ class CogVideoXPipeline(VideoSysPipeline): if self.transformer.config.use_rotary_positional_embeddings else None ) - if video_flow_features is not None and do_classifier_free_guidance: - video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous() + if tora is not None and do_classifier_free_guidance: + tora["video_flow_features"] = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous() # 9. Controlnet if controlnet is not None: @@ -784,11 +783,11 @@ class CogVideoXPipeline(VideoSysPipeline): else: for c in context_queue: partial_latent_model_input = latent_model_input[:, c, :, :, :] - if video_flow_features is not None: + if tora is not None: if do_classifier_free_guidance: - partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous() + partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous() else: - partial_video_flow_features = video_flow_features[:, c, :, :, :] + partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :] else: partial_video_flow_features = None @@ -869,7 +868,7 @@ class CogVideoXPipeline(VideoSysPipeline): return_dict=False, controlnet_states=controlnet_states, controlnet_weights=control_weights, - video_flow_features=video_flow_features, + video_flow_features=tora["video_flow_features"] if (tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None, )[0] noise_pred = noise_pred.float()