From d76229c49be366677af4db5b57385e84b53f46e0 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 8 Oct 2024 16:22:07 +0300
Subject: [PATCH] controlnet support

https://huggingface.co/TheDenk/cogvideox-2b-controlnet-hed-v1
https://huggingface.co/TheDenk/cogvideox-2b-controlnet-canny-v1
---
 cogvideo_controlnet.py                             | 204 ++++
 custom_cogvideox_transformer_3d.py                 |  14 +
 .../cogvideox_2b_controlnet_example_01.json        | 904 ++++++++++++++++++
 nodes.py                                           |  94 +-
 pipeline_cogvideox.py                              |  38 +-
 5 files changed, 1248 insertions(+), 6 deletions(-)
 create mode 100644 cogvideo_controlnet.py
 create mode 100644 examples/cogvideox_2b_controlnet_example_01.json

diff --git a/cogvideo_controlnet.py b/cogvideo_controlnet.py
new file mode 100644
index 0000000..b7f7399
--- /dev/null
+++ b/cogvideo_controlnet.py
@@ -0,0 +1,204 @@
+# https://github.com/TheDenk/cogvideox-controlnet/blob/main/cogvideo_controlnet.py
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from einops import rearrange
+import torch.nn.functional as F
+from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from diffusers.utils import is_torch_version
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+
+class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        vae_channels: int = 16,
+        in_channels: int = 3,
+        downscale_coef: int = 8,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        time_embed_dim: int = 512,
+        num_layers: int = 8,
+        dropout: float = 0.0,
+        attention_bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        patch_size: int = 2,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned "
+                "positional embeddings enabled. If you're using a custom model and/or believe this should be "
+                "supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        start_channels = in_channels * (downscale_coef ** 2)
+        input_channels = [start_channels, start_channels // 2, start_channels // 4]
+        self.unshuffle = nn.PixelUnshuffle(downscale_coef)
+
+        self.controlnet_encode_first = nn.Sequential(
+            nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
+            nn.GroupNorm(2, input_channels[1]),
+            nn.ReLU(),
+        )
+
+        self.controlnet_encode_second = nn.Sequential(
+            nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
+            nn.GroupNorm(2, input_channels[2]),
+            nn.ReLU(),
+        )
+
+        # 1. Patch embedding
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=vae_channels + input_channels[2],
+            embed_dim=inner_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+
+        self.embedding_dropout = nn.Dropout(dropout)
+
+        # 2. Time embeddings
+        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogVideoXBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.gradient_checkpointing = False
+
+    def compress_time(self, x, num_frames):
+        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
+        batch_size, frames, channels, height, width = x.shape
+        x = rearrange(x, 'b f c h w -> (b h w) c f')
+
+        if x.shape[-1] % 2 == 1:
+            x_first, x_rest = x[..., 0], x[..., 1:]
+            if x_rest.shape[-1] > 0:
+                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+            x = torch.cat([x_first[..., None], x_rest], dim=-1)
+        else:
+            x = F.avg_pool1d(x, kernel_size=2, stride=2)
+        x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
+        return x
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        controlnet_states: torch.Tensor,
+        timestep: Union[int, float, torch.LongTensor],
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        batch_size, num_frames, channels, height, width = controlnet_states.shape
+        # 0. Controlnet encoder
+        controlnet_states = rearrange(controlnet_states, 'b f c h w -> (b f) c h w')
+        controlnet_states = self.unshuffle(controlnet_states)
+        controlnet_states = self.controlnet_encode_first(controlnet_states)
+        controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
+        num_frames = controlnet_states.shape[0] // batch_size
+
+        controlnet_states = self.controlnet_encode_second(controlnet_states)
+        controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
+        controlnet_states = rearrange(controlnet_states, '(b f) c h w -> b f c h w', b=batch_size)
+
+        hidden_states = torch.cat([hidden_states, controlnet_states], dim=2)
+        # 1. Time embedding
+        timesteps = timestep
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
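+        # The conditioning path below (timestep projection/embedding, joint text+video patch
+        # embedding, CogVideoXBlock stack) mirrors diffusers' CogVideoXTransformer3DModel, so the
+        # controlnet blocks are conditioned the same way as the base transformer's blocks.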
+ t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + + controlnet_hidden_states = () + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + controlnet_hidden_states += (hidden_states,) + + if not return_dict: + return (controlnet_hidden_states,) + return Transformer2DModelOutput(sample=controlnet_hidden_states) \ No newline at end of file diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index f2c27fd..aa6a3fb 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -19,6 +19,8 @@ import torch from torch import nn import torch.nn.functional as F +import numpy as np + from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import is_torch_version, logging from diffusers.utils.torch_utils import maybe_allow_in_graph @@ -566,6 +568,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + controlnet_states: torch.Tensor = None, + controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape @@ -615,6 +619,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): image_rotary_emb=image_rotary_emb, ) + if (controlnet_states is not None) and (i < len(controlnet_states)): + controlnet_states_block = controlnet_states[i] + controlnet_block_weight = 1.0 + if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights): + controlnet_block_weight = controlnet_weights[i] + elif isinstance(controlnet_weights, (float, int)): + controlnet_block_weight = controlnet_weights + + hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight + if not self.config.use_rotary_positional_embeddings: # CogVideoX-2B hidden_states = self.norm_final(hidden_states) diff --git a/examples/cogvideox_2b_controlnet_example_01.json b/examples/cogvideox_2b_controlnet_example_01.json new file mode 100644 index 0000000..cd5ccef --- /dev/null +++ b/examples/cogvideox_2b_controlnet_example_01.json @@ -0,0 +1,904 @@ +{ + "last_node_id": 43, + "last_link_id": 77, + "nodes": [ + { + "id": 11, + "type": "CogVideoDecode", + "pos": { + "0": 740, + "1": 580 + }, + "size": { + "0": 300.396484375, + "1": 198 + }, + "flags": {}, + "order": 11, + "mode": 0, + 
"inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 63 + }, + { + "name": "samples", + "type": "LATENT", + "link": 64 + } + ], + "outputs": [ + { + "name": "images", + "type": "IMAGE", + "links": [ + 76 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoDecode" + }, + "widgets_values": [ + false, + 240, + 360, + 0.2, + 0.2, + true + ] + }, + { + "id": 41, + "type": "HEDPreprocessor", + "pos": { + "0": -570, + "1": -76 + }, + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 73 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 74 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "HEDPreprocessor" + }, + "widgets_values": [ + "enable", + 768 + ] + }, + { + "id": 31, + "type": "CogVideoTextEncode", + "pos": { + "0": 140, + "1": 660 + }, + "size": { + "0": 463.01251220703125, + "1": 124 + }, + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 56 + } + ], + "outputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "links": [ + 62 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoTextEncode" + }, + "widgets_values": [ + "", + 1, + true + ] + }, + { + "id": 20, + "type": "CLIPLoader", + "pos": { + "0": -390, + "1": 480 + }, + "size": { + "0": 451.30548095703125, + "1": 82 + }, + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 54, + 56 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CLIPLoader" + }, + "widgets_values": [ + "t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors", + "sd3" + ] + }, + { + "id": 38, + "type": "VHS_LoadVideo", + "pos": { + "0": -847, + "1": -78 + }, + "size": [ + 247.455078125, + 427.63671875 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null, + "shape": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 73 + ], + "slot_index": 0 + }, + { + "name": "frame_count", + "type": "INT", + "links": null + }, + { + "name": "audio", + "type": "AUDIO", + "links": null + }, + { + "name": "video_info", + "type": "VHS_VIDEOINFO", + "links": null + } + ], + "properties": { + "Node name for S&R": "VHS_LoadVideo" + }, + "widgets_values": { + "video": "car.mp4", + "force_rate": 0, + "force_size": "Disabled", + "custom_width": 512, + "custom_height": 512, + "frame_load_cap": 49, + "skip_first_frames": 0, + "select_every_nth": 1, + "choose video to upload": "image", + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "frame_load_cap": 49, + "skip_first_frames": 0, + "force_rate": 0, + "filename": "car.mp4", + "type": "input", + "format": "video/mp4", + "select_every_nth": 1 + }, + "muted": false + } + } + }, + { + "id": 39, + "type": "ImageResizeKJ", + "pos": { + "0": -563, + "1": 63 + }, + "size": { + "0": 315, + "1": 266 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 74 + }, + { + "name": "get_image_size", + "type": "IMAGE", + "link": null, + "shape": 7 + }, + { + "name": "width_input", + "type": "INT", + "link": null, + "widget": { + "name": "width_input" + }, + "shape": 7 + 
}, + { + "name": "height_input", + "type": "INT", + "link": null, + "widget": { + "name": "height_input" + }, + "shape": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 71 + ], + "slot_index": 0 + }, + { + "name": "width", + "type": "INT", + "links": null + }, + { + "name": "height", + "type": "INT", + "links": null + } + ], + "properties": { + "Node name for S&R": "ImageResizeKJ" + }, + "widgets_values": [ + 720, + 480, + "lanczos", + false, + 2, + 0, + 0, + "disabled" + ] + }, + { + "id": 40, + "type": "GetImageSizeAndCount", + "pos": { + "0": -190, + "1": -68 + }, + "size": { + "0": 277.20001220703125, + "1": 86 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 71 + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 72, + 75 + ], + "slot_index": 0 + }, + { + "name": "720 width", + "type": "INT", + "links": null + }, + { + "name": "480 height", + "type": "INT", + "links": null + }, + { + "name": "49 count", + "type": "INT", + "links": null + } + ], + "properties": { + "Node name for S&R": "GetImageSizeAndCount" + }, + "widgets_values": [] + }, + { + "id": 37, + "type": "CogVideoControlNet", + "pos": { + "0": 133, + "1": 131 + }, + "size": { + "0": 367.79998779296875, + "1": 126 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "COGVIDECONTROLNETMODEL", + "link": 67 + }, + { + "name": "images", + "type": "IMAGE", + "link": 72 + } + ], + "outputs": [ + { + "name": "cogvideo_controlnet", + "type": "COGVIDECONTROLNET", + "links": [ + 68 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CogVideoControlNet" + }, + "widgets_values": [ + 1, + 0, + 1 + ] + }, + { + "id": 35, + "type": "DownloadAndLoadCogVideoControlNet", + "pos": { + "0": -187, + "1": -207 + }, + "size": { + "0": 378, + "1": 58 + }, + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "cogvideo_controlnet", + "type": "COGVIDECONTROLNETMODEL", + "links": [ + 67 + ] + } + ], + "properties": { + "Node name for S&R": "DownloadAndLoadCogVideoControlNet" + }, + "widgets_values": [ + "TheDenk/cogvideox-2b-controlnet-hed-v1" + ] + }, + { + "id": 1, + "type": "DownloadAndLoadCogVideoModel", + "pos": { + "0": -157, + "1": -473 + }, + "size": { + "0": 315, + "1": 194 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + { + "name": "pab_config", + "type": "PAB_CONFIG", + "link": null, + "shape": 7 + }, + { + "name": "block_edit", + "type": "TRANSFORMERBLOCKS", + "link": null, + "shape": 7 + }, + { + "name": "lora", + "type": "COGLORA", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "cogvideo_pipe", + "type": "COGVIDEOPIPE", + "links": [ + 60 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "DownloadAndLoadCogVideoModel" + }, + "widgets_values": [ + "THUDM/CogVideoX-2b", + "bf16", + "disabled", + "disabled", + false + ] + }, + { + "id": 30, + "type": "CogVideoTextEncode", + "pos": { + "0": 130, + "1": 350 + }, + "size": [ + 475.7874994452536, + 231.2989729014987 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 54 + } + ], + "outputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "links": [ + 61 + ], + "slot_index": 0, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoTextEncode" + }, + "widgets_values": [ + "car is moving among mountains", + 1, + 
true + ] + }, + { + "id": 34, + "type": "CogVideoSampler", + "pos": { + "0": 730, + "1": 170 + }, + "size": { + "0": 315.8404846191406, + "1": 370 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "pipeline", + "type": "COGVIDEOPIPE", + "link": 60 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 61 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 62 + }, + { + "name": "samples", + "type": "LATENT", + "link": null, + "shape": 7 + }, + { + "name": "image_cond_latents", + "type": "LATENT", + "link": null, + "shape": 7 + }, + { + "name": "context_options", + "type": "COGCONTEXT", + "link": null, + "shape": 7 + }, + { + "name": "controlnet", + "type": "COGVIDECONTROLNET", + "link": 68, + "shape": 7 + } + ], + "outputs": [ + { + "name": "cogvideo_pipe", + "type": "COGVIDEOPIPE", + "links": [ + 63 + ], + "shape": 3 + }, + { + "name": "samples", + "type": "LATENT", + "links": [ + 64 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "CogVideoSampler" + }, + "widgets_values": [ + 480, + 720, + 49, + 32, + 6, + 806286757407563, + "fixed", + "CogVideoXDDIM", + 1 + ] + }, + { + "id": 42, + "type": "ImageConcatMulti", + "pos": { + "0": 1139, + "1": -19 + }, + "size": { + "0": 210, + "1": 150 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "image_1", + "type": "IMAGE", + "link": 75 + }, + { + "name": "image_2", + "type": "IMAGE", + "link": 76 + } + ], + "outputs": [ + { + "name": "images", + "type": "IMAGE", + "links": [ + 77 + ], + "slot_index": 0 + } + ], + "properties": {}, + "widgets_values": [ + 2, + "right", + false, + null + ] + }, + { + "id": 43, + "type": "VHS_VideoCombine", + "pos": { + "0": 1154, + "1": 202 + }, + "size": [ + 778.7022705078125, + 576.9007568359375 + ], + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 77 + }, + { + "name": "audio", + "type": "AUDIO", + "link": null, + "shape": 7 + }, + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null, + "shape": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "Filenames", + "type": "VHS_FILENAMES", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoCombine" + }, + "widgets_values": { + "frame_rate": 8, + "loop_count": 0, + "filename_prefix": "CogVideoX5B", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 19, + "save_metadata": true, + "pingpong": false, + "save_output": false, + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "filename": "CogVideoX5B_00007.mp4", + "subfolder": "", + "type": "temp", + "format": "video/h264-mp4", + "frame_rate": 8 + }, + "muted": false + } + } + } + ], + "links": [ + [ + 54, + 20, + 0, + 30, + 0, + "CLIP" + ], + [ + 56, + 20, + 0, + 31, + 0, + "CLIP" + ], + [ + 60, + 1, + 0, + 34, + 0, + "COGVIDEOPIPE" + ], + [ + 61, + 30, + 0, + 34, + 1, + "CONDITIONING" + ], + [ + 62, + 31, + 0, + 34, + 2, + "CONDITIONING" + ], + [ + 63, + 34, + 0, + 11, + 0, + "COGVIDEOPIPE" + ], + [ + 64, + 34, + 1, + 11, + 1, + "LATENT" + ], + [ + 67, + 35, + 0, + 37, + 0, + "COGVIDECONTROLNETMODEL" + ], + [ + 68, + 37, + 0, + 34, + 6, + "COGVIDECONTROLNET" + ], + [ + 71, + 39, + 0, + 40, + 0, + "IMAGE" + ], + [ + 72, + 40, + 0, + 37, + 1, + "IMAGE" + ], + [ + 73, + 38, + 0, + 41, + 0, + "IMAGE" + ], + [ + 74, + 41, + 0, + 39, + 0, + "IMAGE" + ], + [ + 75, + 40, + 0, + 42, + 0, + "IMAGE" + ], + [ + 76, + 11, + 0, + 
42, + 1, + "IMAGE" + ], + [ + 77, + 42, + 0, + 43, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.6303940863129801, + "offset": [ + 1194.8126582413695, + 661.2034019206458 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index 8845fe9..0d06eea 100644 --- a/nodes.py +++ b/nodes.py @@ -594,6 +594,52 @@ class DownloadAndLoadCogVideoGGUFModel: } return (pipeline,) + +class DownloadAndLoadCogVideoControlNet: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ( + [ + "TheDenk/cogvideox-2b-controlnet-hed-v1", + "TheDenk/cogvideox-2b-controlnet-canny-v1", + ], + ), + + }, + } + + RETURN_TYPES = ("COGVIDECONTROLNETMODEL",) + RETURN_NAMES = ("cogvideo_controlnet", ) + FUNCTION = "loadmodel" + CATEGORY = "CogVideoWrapper" + + def loadmodel(self, model): + from .cogvideo_controlnet import CogVideoXControlnet + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + + download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'ControlNet') + base_path = os.path.join(download_path, (model.split("/")[-1])) + + if not os.path.exists(base_path): + log.info(f"Downloading model to: {base_path}") + from huggingface_hub import snapshot_download + + snapshot_download( + repo_id=model, + ignore_patterns=["*text_encoder*", "*tokenizer*"], + local_dir=base_path, + local_dir_use_symlinks=False, + ) + + controlnet = CogVideoXControlnet.from_pretrained(base_path) + + return (controlnet,) class CogVideoEncodePrompt: @classmethod @@ -855,6 +901,7 @@ class CogVideoSampler: "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "image_cond_latents": ("LATENT", ), "context_options": ("COGCONTEXT", ), + "controlnet": ("COGVIDECONTROLNET",), } } @@ -864,7 +911,7 @@ class CogVideoSampler: CATEGORY = "CogVideoWrapper" def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None, - denoise_strength=1.0, image_cond_latents=None, context_options=None): + denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None): mm.soft_empty_cache() base_path = pipeline["base_path"] @@ -921,7 +968,8 @@ class CogVideoSampler: context_frames=context_frames, context_stride= context_stride, context_overlap= context_overlap, - freenoise=context_options["freenoise"] if context_options is not None else None + freenoise=context_options["freenoise"] if context_options is not None else None, + controlnet=controlnet ) if not pipeline["cpu_offloading"]: pipe.transformer.to(offload_device) @@ -1281,6 +1329,41 @@ class CogVideoControlImageEncode: } return (control_latents, width, height) + +class CogVideoControlNet: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "controlnet": ("COGVIDECONTROLNETMODEL",), + "images": ("IMAGE", ), + "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("COGVIDECONTROLNET",) + RETURN_NAMES = ("cogvideo_controlnet",) + FUNCTION = "encode" + CATEGORY = "CogVideoWrapper" + + def encode(self, controlnet, images, control_strength, control_start_percent, control_end_percent): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + + B, H, W, C = images.shape + + 
control_frames = images.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1  # (B, H, W, C) in [0, 1] -> (1, B, C, H, W) in [-1, 1]
+
+        controlnet = {
+            "control_model": controlnet,
+            "control_frames": control_frames,
+            "control_strength": control_strength,
+            "control_start": control_start_percent,
+            "control_end": control_end_percent,
+        }
+
+        return (controlnet,)
 
 class CogVideoContextOptions:
@@ -1427,7 +1510,9 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoTransformerEdit": CogVideoTransformerEdit,
     "CogVideoControlImageEncode": CogVideoControlImageEncode,
     "CogVideoLoraSelect": CogVideoLoraSelect,
-    "CogVideoContextOptions": CogVideoContextOptions
+    "CogVideoContextOptions": CogVideoContextOptions,
+    "CogVideoControlNet": CogVideoControlNet,
+    "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1445,5 +1530,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoTransformerEdit": "CogVideo TransformerEdit",
     "CogVideoControlImageEncode": "CogVideo Control ImageEncode",
     "CogVideoLoraSelect": "CogVideo LoraSelect",
-    "CogVideoContextOptions": "CogVideo Context Options"
+    "CogVideoContextOptions": "CogVideo Context Options",
+    "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet"
 }
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 6b9e909..10d3425 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -387,6 +387,8 @@ class CogVideoXPipeline(VideoSysPipeline):
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
+        controlnet: Optional[dict] = None,
+
    ):
        """
        Function invoked when calling the pipeline for generation.
@@ -536,7 +538,7 @@ class CogVideoXPipeline(VideoSysPipeline):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)
 
-        # 8.5. Temporal tiling prep
+        # 8. Context schedule and temporal tiling
         if context_schedule is not None and context_schedule == "temporal_tiling":
             t_tile_length = context_frames
             t_tile_overlap = context_overlap
@@ -562,7 +564,17 @@ class CogVideoXPipeline(VideoSysPipeline):
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-
+        # 9. Controlnet
+
+        if controlnet is not None:
+            self.controlnet = controlnet["control_model"].to(device)
+            control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
+            control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
+            control_strength = controlnet["control_strength"]
+            control_start = controlnet["control_start"]
+            control_end = controlnet["control_end"]
+        else:
+            control_strength = 1.0  # the transformer receives controlnet_weights even when no controlnet is attached
+
+        # 10. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             old_pred_original_sample = None  # for DPM-solver++
             for i, t in enumerate(timesteps):
@@ -744,6 +756,26 @@
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
+
+                controlnet_states = None
+                if controlnet is not None:
+                    current_sampling_percent = i / len(timesteps)
+                    if control_start <= current_sampling_percent < control_end:
+                        # extract the per-block controlnet hidden states
+                        controlnet_states = self.controlnet(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            image_rotary_emb=image_rotary_emb,
+                            controlnet_states=control_frames,
+                            timestep=timestep,
+                            return_dict=False,
+                        )[0]
+                        if isinstance(controlnet_states, (tuple, list)):
+                            controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
+                        else:
+                            controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)
+
                 # predict noise model_output
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -751,6 +783,8 @@
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
+                    controlnet_states=controlnet_states,
+                    controlnet_weights=control_strength,
                 )[0]
                 noise_pred = noise_pred.float()
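
Usage note (not part of the patch): a minimal sketch of how the new pieces fit together outside
ComfyUI, assuming one of the two checkpoints linked in the commit message and a HED/canny control
video already shaped (frames, height, width, channels) in [0, 1]; the shapes and the weight
schedule below are illustrative, not prescriptive.

    import torch
    from cogvideo_controlnet import CogVideoXControlnet

    # CogVideoXControlnet inherits diffusers' ModelMixin, so from_pretrained works as usual
    controlnet = CogVideoXControlnet.from_pretrained(
        "TheDenk/cogvideox-2b-controlnet-hed-v1", torch_dtype=torch.bfloat16
    ).to("cuda")

    # same preprocessing as the CogVideoControlNet node:
    # (F, H, W, C) in [0, 1] -> (1, F, C, H, W) in [-1, 1]
    hed_frames = torch.rand(49, 480, 720, 3)  # placeholder control video
    control_frames = hed_frames.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1

    # Per step, the pipeline calls the controlnet and adds its per-block hidden states into the
    # base transformer (see pipeline_cogvideox.py above). controlnet_weights accepts a scalar or
    # one weight per controlnet block (num_layers = 8), e.g. to fade the control out:
    controlnet_weights = [1.0, 1.0, 0.8, 0.8, 0.6, 0.6, 0.4, 0.4]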