kijai 2024-10-08 16:22:07 +03:00
parent d83f64aa3e
commit d76229c49b
5 changed files with 1248 additions and 6 deletions

cogvideo_controlnet.py (new file)

@ -0,0 +1,204 @@
# https://github.com/TheDenk/cogvideox-controlnet/blob/main/cogvideo_controlnet.py
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from einops import rearrange
import torch.nn.functional as F
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
vae_channels: int = 16,
in_channels: int = 3,
downscale_coef: int = 8,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
num_layers: int = 8,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
raise ValueError(
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
"issue at https://github.com/huggingface/diffusers/issues."
)
start_channels = in_channels * (downscale_coef ** 2)
input_channels = [start_channels, start_channels // 2, start_channels // 4]
self.unshuffle = nn.PixelUnshuffle(downscale_coef)
self.controlnet_encode_first = nn.Sequential(
nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
nn.GroupNorm(2, input_channels[1]),
nn.ReLU(),
)
self.controlnet_encode_second = nn.Sequential(
nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
nn.GroupNorm(2, input_channels[2]),
nn.ReLU(),
)
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=vae_channels + input_channels[2],
embed_dim=inner_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.gradient_checkpointing = False
def compress_time(self, x, num_frames):
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
batch_size, frames, channels, height, width = x.shape
x = rearrange(x, 'b f c h w -> (b h w) c f')
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
else:
x = F.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
return x
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
controlnet_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = controlnet_states.shape
# 0. Controlnet encoder
controlnet_states = rearrange(controlnet_states, 'b f c h w -> (b f) c h w')
controlnet_states = self.unshuffle(controlnet_states)
controlnet_states = self.controlnet_encode_first(controlnet_states)
controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
num_frames = controlnet_states.shape[0] // batch_size
controlnet_states = self.controlnet_encode_second(controlnet_states)
controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
controlnet_states = rearrange(controlnet_states, '(b f) c h w -> b f c h w', b=batch_size)
hidden_states = torch.cat([hidden_states, controlnet_states], dim=2)
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
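        # 2. Patch embedding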
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
controlnet_hidden_states = ()
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
controlnet_hidden_states += (hidden_states,)
if not return_dict:
return (controlnet_hidden_states,)
return Transformer2DModelOutput(sample=controlnet_hidden_states)
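A note on the temporal math above: compress_time halves the frame axis (odd-length clips keep their first frame and average-pool the rest), and forward applies it twice, so the pixel-space control frames end up with the same frame count as CogVideoX's 4x temporally compressed latents. A quick sketch of the arithmetic; the helper name is illustrative:

def halved_frames(frames: int) -> int:
    # mirrors compress_time: odd lengths keep frame 0 and avg-pool the remainder
    return 1 + (frames - 1) // 2 if frames % 2 == 1 else frames // 2

frames = 49                        # the sample_frames default
for _ in range(2):                 # compress_time runs twice in forward()
    frames = halved_frames(frames)
print(frames)                      # 13, matching (49 - 1) // temporal_compression_ratio + 1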


@ -19,6 +19,8 @@ import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
@ -566,6 +568,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@ -615,6 +619,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
image_rotary_emb=image_rotary_emb,
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
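Because controlnet_weights accepts either a scalar or a per-block sequence indexed by transformer block, callers can schedule how strongly each controlnet hidden state is injected. A hedged sketch, assuming the default 8 controlnet layers; the ramp values are illustrative:

import numpy as np

# one weight per controlnet block; deeper blocks receive progressively less influence
controlnet_weights = np.linspace(1.0, 0.2, num=8).tolist()

# the scalar form applies the same weight to every block:
# controlnet_weights = 0.75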


@ -0,0 +1,904 @@
{
"last_node_id": 43,
"last_link_id": 77,
"nodes": [
{
"id": 11,
"type": "CogVideoDecode",
"pos": {
"0": 740,
"1": 580
},
"size": {
"0": 300.396484375,
"1": 198
},
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 63
},
{
"name": "samples",
"type": "LATENT",
"link": 64
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
76
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoDecode"
},
"widgets_values": [
false,
240,
360,
0.2,
0.2,
true
]
},
{
"id": 41,
"type": "HEDPreprocessor",
"pos": {
"0": -570,
"1": -76
},
"size": {
"0": 315,
"1": 82
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 73
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
74
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "HEDPreprocessor"
},
"widgets_values": [
"enable",
768
]
},
{
"id": 31,
"type": "CogVideoTextEncode",
"pos": {
"0": 140,
"1": 660
},
"size": {
"0": 463.01251220703125,
"1": 124
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 56
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
62
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"",
1,
true
]
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -390,
"1": 480
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
54,
56
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 38,
"type": "VHS_LoadVideo",
"pos": {
"0": -847,
"1": -78
},
"size": [
247.455078125,
427.63671875
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null,
"shape": 7
},
{
"name": "vae",
"type": "VAE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
73
],
"slot_index": 0
},
{
"name": "frame_count",
"type": "INT",
"links": null
},
{
"name": "audio",
"type": "AUDIO",
"links": null
},
{
"name": "video_info",
"type": "VHS_VIDEOINFO",
"links": null
}
],
"properties": {
"Node name for S&R": "VHS_LoadVideo"
},
"widgets_values": {
"video": "car.mp4",
"force_rate": 0,
"force_size": "Disabled",
"custom_width": 512,
"custom_height": 512,
"frame_load_cap": 49,
"skip_first_frames": 0,
"select_every_nth": 1,
"choose video to upload": "image",
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"frame_load_cap": 49,
"skip_first_frames": 0,
"force_rate": 0,
"filename": "car.mp4",
"type": "input",
"format": "video/mp4",
"select_every_nth": 1
},
"muted": false
}
}
},
{
"id": 39,
"type": "ImageResizeKJ",
"pos": {
"0": -563,
"1": 63
},
"size": {
"0": 315,
"1": 266
},
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 74
},
{
"name": "get_image_size",
"type": "IMAGE",
"link": null,
"shape": 7
},
{
"name": "width_input",
"type": "INT",
"link": null,
"widget": {
"name": "width_input"
},
"shape": 7
},
{
"name": "height_input",
"type": "INT",
"link": null,
"widget": {
"name": "height_input"
},
"shape": 7
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
71
],
"slot_index": 0
},
{
"name": "width",
"type": "INT",
"links": null
},
{
"name": "height",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "ImageResizeKJ"
},
"widgets_values": [
720,
480,
"lanczos",
false,
2,
0,
0,
"disabled"
]
},
{
"id": 40,
"type": "GetImageSizeAndCount",
"pos": {
"0": -190,
"1": -68
},
"size": {
"0": 277.20001220703125,
"1": 86
},
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 71
}
],
"outputs": [
{
"name": "image",
"type": "IMAGE",
"links": [
72,
75
],
"slot_index": 0
},
{
"name": "720 width",
"type": "INT",
"links": null
},
{
"name": "480 height",
"type": "INT",
"links": null
},
{
"name": "49 count",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSizeAndCount"
},
"widgets_values": []
},
{
"id": 37,
"type": "CogVideoControlNet",
"pos": {
"0": 133,
"1": 131
},
"size": {
"0": 367.79998779296875,
"1": 126
},
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "controlnet",
"type": "COGVIDECONTROLNETMODEL",
"link": 67
},
{
"name": "images",
"type": "IMAGE",
"link": 72
}
],
"outputs": [
{
"name": "cogvideo_controlnet",
"type": "COGVIDECONTROLNET",
"links": [
68
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CogVideoControlNet"
},
"widgets_values": [
1,
0,
1
]
},
{
"id": 35,
"type": "DownloadAndLoadCogVideoControlNet",
"pos": {
"0": -187,
"1": -207
},
"size": {
"0": 378,
"1": 58
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "cogvideo_controlnet",
"type": "COGVIDECONTROLNETMODEL",
"links": [
67
]
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoControlNet"
},
"widgets_values": [
"TheDenk/cogvideox-2b-controlnet-hed-v1"
]
},
{
"id": 1,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": -157,
"1": -473
},
"size": {
"0": 315,
"1": 194
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "pab_config",
"type": "PAB_CONFIG",
"link": null,
"shape": 7
},
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null,
"shape": 7
},
{
"name": "lora",
"type": "COGLORA",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
60
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"THUDM/CogVideoX-2b",
"bf16",
"disabled",
"disabled",
false
]
},
{
"id": 30,
"type": "CogVideoTextEncode",
"pos": {
"0": 130,
"1": 350
},
"size": [
475.7874994452536,
231.2989729014987
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 54
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
61
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"car is moving among mountains",
1,
true
]
},
{
"id": 34,
"type": "CogVideoSampler",
"pos": {
"0": 730,
"1": 170
},
"size": {
"0": 315.8404846191406,
"1": 370
},
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 60
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 61
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 62
},
{
"name": "samples",
"type": "LATENT",
"link": null,
"shape": 7
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": null,
"shape": 7
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null,
"shape": 7
},
{
"name": "controlnet",
"type": "COGVIDECONTROLNET",
"link": 68,
"shape": 7
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
63
],
"shape": 3
},
{
"name": "samples",
"type": "LATENT",
"links": [
64
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
480,
720,
49,
32,
6,
806286757407563,
"fixed",
"CogVideoXDDIM",
1
]
},
{
"id": 42,
"type": "ImageConcatMulti",
"pos": {
"0": 1139,
"1": -19
},
"size": {
"0": 210,
"1": 150
},
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "image_1",
"type": "IMAGE",
"link": 75
},
{
"name": "image_2",
"type": "IMAGE",
"link": 76
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
77
],
"slot_index": 0
}
],
"properties": {},
"widgets_values": [
2,
"right",
false,
null
]
},
{
"id": 43,
"type": "VHS_VideoCombine",
"pos": {
"0": 1154,
"1": 202
},
"size": [
778.7022705078125,
576.9007568359375
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 77
},
{
"name": "audio",
"type": "AUDIO",
"link": null,
"shape": 7
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null,
"shape": 7
},
{
"name": "vae",
"type": "VAE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "CogVideoX5B",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"pingpong": false,
"save_output": false,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "CogVideoX5B_00007.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 8
},
"muted": false
}
}
}
],
"links": [
[
54,
20,
0,
30,
0,
"CLIP"
],
[
56,
20,
0,
31,
0,
"CLIP"
],
[
60,
1,
0,
34,
0,
"COGVIDEOPIPE"
],
[
61,
30,
0,
34,
1,
"CONDITIONING"
],
[
62,
31,
0,
34,
2,
"CONDITIONING"
],
[
63,
34,
0,
11,
0,
"COGVIDEOPIPE"
],
[
64,
34,
1,
11,
1,
"LATENT"
],
[
67,
35,
0,
37,
0,
"COGVIDECONTROLNETMODEL"
],
[
68,
37,
0,
34,
6,
"COGVIDECONTROLNET"
],
[
71,
39,
0,
40,
0,
"IMAGE"
],
[
72,
40,
0,
37,
1,
"IMAGE"
],
[
73,
38,
0,
41,
0,
"IMAGE"
],
[
74,
41,
0,
39,
0,
"IMAGE"
],
[
75,
40,
0,
42,
0,
"IMAGE"
],
[
76,
11,
0,
42,
1,
"IMAGE"
],
[
77,
42,
0,
43,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.6303940863129801,
"offset": [
1194.8126582413695,
661.2034019206458
]
}
},
"version": 0.4
}
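For orientation: the workflow above loads a 49-frame clip (VHS_LoadVideo), runs it through HEDPreprocessor, resizes to 720x480 (ImageResizeKJ), and feeds the frames to CogVideoControlNet backed by TheDenk/cogvideox-2b-controlnet-hed-v1; CogVideoSampler combines that with the THUDM/CogVideoX-2b pipeline and T5 text conditioning, and the decoded video is concatenated side by side with the control frames (ImageConcatMulti) before VHS_VideoCombine writes it out.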


@ -594,6 +594,52 @@ class DownloadAndLoadCogVideoGGUFModel:
}
return (pipeline,)
class DownloadAndLoadCogVideoControlNet:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (
[
"TheDenk/cogvideox-2b-controlnet-hed-v1",
"TheDenk/cogvideox-2b-controlnet-canny-v1",
],
),
},
}
RETURN_TYPES = ("COGVIDECONTROLNETMODEL",)
RETURN_NAMES = ("cogvideo_controlnet", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model):
from .cogvideo_controlnet import CogVideoXControlnet
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'ControlNet')
base_path = os.path.join(download_path, (model.split("/")[-1]))
if not os.path.exists(base_path):
log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=model,
ignore_patterns=["*text_encoder*", "*tokenizer*"],
local_dir=base_path,
local_dir_use_symlinks=False,
)
controlnet = CogVideoXControlnet.from_pretrained(base_path)
return (controlnet,)
class CogVideoEncodePrompt:
@classmethod
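Outside ComfyUI, the same checkpoint can be fetched and instantiated directly. A minimal sketch, assuming the repo id from the node above and cogvideo_controlnet.py importable from the working directory:

from huggingface_hub import snapshot_download
from cogvideo_controlnet import CogVideoXControlnet

local_dir = snapshot_download(
    repo_id="TheDenk/cogvideox-2b-controlnet-hed-v1",
    ignore_patterns=["*text_encoder*", "*tokenizer*"],
)
controlnet = CogVideoXControlnet.from_pretrained(local_dir)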
@ -855,6 +901,7 @@ class CogVideoSampler:
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"image_cond_latents": ("LATENT", ),
"context_options": ("COGCONTEXT", ),
"controlnet": ("COGVIDECONTROLNET",),
}
}
@ -864,7 +911,7 @@ class CogVideoSampler:
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None):
mm.soft_empty_cache()
base_path = pipeline["base_path"]
@ -921,7 +968,8 @@ class CogVideoSampler:
context_frames=context_frames,
context_stride= context_stride,
context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None,
controlnet=controlnet
)
if not pipeline["cpu_offloading"]:
pipe.transformer.to(offload_device)
@ -1281,6 +1329,41 @@ class CogVideoControlImageEncode:
}
return (control_latents, width, height)
class CogVideoControlNet:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"controlnet": ("COGVIDECONTROLNETMODEL",),
"images": ("IMAGE", ),
"control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = ("COGVIDECONTROLNET",)
RETURN_NAMES = ("cogvideo_controlnet",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, controlnet, images, control_strength, control_start_percent, control_end_percent):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
B, H, W, C = images.shape
control_frames = images.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1
controlnet = {
"control_model": controlnet,
"control_frames": control_frames,
"control_strength": control_strength,
"control_start": control_start_percent,
"control_end": control_end_percent,
}
return (controlnet,)
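The permute/rescale step above converts ComfyUI's (frames, H, W, C) images in [0, 1] into the (batch, frames, C, H, W) layout in [-1, 1] expected by the controlnet encoder. A quick shape check with dummy data:

import torch

images = torch.rand(49, 480, 720, 3)   # a ComfyUI IMAGE batch: frames, H, W, C in [0, 1]
control_frames = images.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1
print(control_frames.shape)            # torch.Size([1, 49, 3, 480, 720])
print(control_frames.min().item() >= -1.0, control_frames.max().item() <= 1.0)  # True True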
class CogVideoContextOptions:
@ -1427,7 +1510,9 @@ NODE_CLASS_MAPPINGS = {
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoContextOptions": CogVideoContextOptions
"CogVideoContextOptions": CogVideoContextOptions,
"CogVideoControlNet": CogVideoControlNet,
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -1445,5 +1530,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoContextOptions": "CogVideo Context Options"
"CogVideoContextOptions": "CogVideo Context Options",
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet"
}


@ -387,6 +387,8 @@ class CogVideoXPipeline(VideoSysPipeline):
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
controlnet: Optional[dict] = None,
):
"""
Function invoked when calling the pipeline for generation.
@ -536,7 +538,7 @@ class CogVideoXPipeline(VideoSysPipeline):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps)
# 8. context schedule and temporal tiling
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
@ -562,7 +564,17 @@ class CogVideoXPipeline(VideoSysPipeline):
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 9. Controlnet
# defaults keep the denoising loop safe when no controlnet dict is supplied
control_strength, control_start, control_end = 1.0, 0.0, 0.0
if controlnet is not None:
    self.controlnet = controlnet["control_model"].to(device)
    control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
    # duplicate along batch so the CFG cond/uncond halves share the control signal
    control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
    control_strength = controlnet["control_strength"]
    control_start = controlnet["control_start"]
    control_end = controlnet["control_end"]
# 10. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
old_pred_original_sample = None # for DPM-solver++
for i, t in enumerate(timesteps):
@ -744,6 +756,26 @@ class CogVideoXPipeline(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
current_sampling_percent = i / len(timesteps)
controlnet_states = None
if control_start <= current_sampling_percent < control_end:  # inclusive start so step 0 is covered at the 0.0 default
# extract controlnet hidden state
controlnet_states = self.controlnet(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
image_rotary_emb=image_rotary_emb,
controlnet_states=control_frames,
timestep=timestep,
return_dict=False,
)[0]
if isinstance(controlnet_states, (tuple, list)):
controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
else:
controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
@ -751,6 +783,8 @@ class CogVideoXPipeline(VideoSysPipeline):
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_strength,
)[0]
noise_pred = noise_pred.float()
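Given that gating, the controlnet only contributes while the current sampling percent falls inside the [control_start, control_end) window. A small sketch of which steps that covers, assuming one scheduler step per timestep:

steps = 32
control_start, control_end = 0.0, 0.5
active = [i for i in range(steps) if control_start <= i / steps < control_end]
print(active)   # [0, 1, ..., 15]: control guidance over the first half of sampling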