diff --git a/.gitignore b/.gitignore
index da24bdf..d75870d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ master_ip
logs/
*.DS_Store
.idea
-*.pt
\ No newline at end of file
+*.pt
+tools/
\ No newline at end of file
diff --git a/cogvideox_fun/autoencoder_magvit.py b/cogvideox_fun/autoencoder_magvit.py
deleted file mode 100644
index 9c2b906..0000000
--- a/cogvideox_fun/autoencoder_magvit.py
+++ /dev/null
@@ -1,1296 +0,0 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders.single_file_model import FromOriginalModelMixin
-from diffusers.utils import logging
-from diffusers.utils.accelerate_utils import apply_forward_hook
-from diffusers.models.activations import get_activation
-from diffusers.models.downsampling import CogVideoXDownsample3D
-from diffusers.models.modeling_outputs import AutoencoderKLOutput
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.upsampling import CogVideoXUpsample3D
-from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-class CogVideoXSafeConv3d(nn.Conv3d):
- r"""
- A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
- """
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
-
- # Set to 2GB, suitable for CuDNN
- if memory_count > 2:
- kernel_size = self.kernel_size[0]
- part_num = int(memory_count / 2) + 1
- input_chunks = torch.chunk(input, part_num, dim=2)
-
- if kernel_size > 1:
- input_chunks = [input_chunks[0]] + [
- torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
- for i in range(1, len(input_chunks))
- ]
-
- output_chunks = []
- for input_chunk in input_chunks:
- output_chunks.append(super().forward(input_chunk))
- output = torch.cat(output_chunks, dim=2)
- return output
- else:
- return super().forward(input)
-
-
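# --- Editor's note: the sketch below is not part of the deleted file. It only re-derives the
# memory estimate that CogVideoXSafeConv3d.forward uses to decide whether to chunk the input
# along the temporal dimension. The tensor shape is an illustrative assumption (half precision,
# 2 bytes per element), not a value taken from this repository.
def safe_conv3d_chunking(shape=(1, 128, 49, 480, 720), bytes_per_element=2, threshold_gb=2):
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    memory_gb = num_elements * bytes_per_element / 1024**3   # ~4.04 GB for this shape
    if memory_gb > threshold_gb:
        part_num = int(memory_gb / threshold_gb) + 1          # -> 3 temporal chunks
    else:
        part_num = 1                                          # small enough, no chunking
    return memory_gb, part_num

print(safe_conv3d_chunking())   # approx (4.04, 3)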
-class CogVideoXCausalConv3d(nn.Module):
- r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
-
- Args:
- in_channels (`int`): Number of channels in the input tensor.
- out_channels (`int`): Number of output channels produced by the convolution.
- kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
- stride (`int`, defaults to `1`): Stride of the convolution.
- dilation (`int`, defaults to `1`): Dilation rate of the convolution.
- pad_mode (`str`, defaults to `"constant"`): Padding mode.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, Tuple[int, int, int]],
- stride: int = 1,
- dilation: int = 1,
- pad_mode: str = "constant",
- ):
- super().__init__()
-
- if isinstance(kernel_size, int):
- kernel_size = (kernel_size,) * 3
-
- time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
-
- self.pad_mode = pad_mode
- time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
- height_pad = height_kernel_size // 2
- width_pad = width_kernel_size // 2
-
- self.height_pad = height_pad
- self.width_pad = width_pad
- self.time_pad = time_pad
- self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
-
- self.temporal_dim = 2
- self.time_kernel_size = time_kernel_size
-
- stride = (stride, 1, 1)
- dilation = (dilation, 1, 1)
- self.conv = CogVideoXSafeConv3d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- dilation=dilation,
- )
-
- self.conv_cache = None
-
- def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
- kernel_size = self.time_kernel_size
- if kernel_size > 1:
- cached_inputs = (
- [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
- )
- inputs = torch.cat(cached_inputs + [inputs], dim=2)
- return inputs
-
- def _clear_fake_context_parallel_cache(self):
- del self.conv_cache
- self.conv_cache = None
-
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
- inputs = self.fake_context_parallel_forward(inputs)
-
- self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage, but it's only a few
-        # hundred megabytes, so let's not do it for now
- self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
-
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
-
- output = self.conv(inputs)
- return output
-
-
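# --- Editor's note: the sketch below is not part of the deleted file. It just spells out the
# padding arithmetic from CogVideoXCausalConv3d.__init__ for the common case kernel_size=3,
# stride=1, dilation=1: all temporal padding goes in front of the clip, so no frame ever sees
# future frames.
time_kernel_size, height_kernel_size, width_kernel_size = 3, 3, 3
stride, dilation = 1, 1
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)   # -> 2
height_pad = height_kernel_size // 2                          # -> 1
width_pad = width_kernel_size // 2                            # -> 1
# F.pad ordering is (w_left, w_right, h_top, h_bottom, t_front, t_back)
time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
print(time_causal_padding)   # (1, 1, 1, 1, 2, 0): two frames of front padding, none at the back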
-class CogVideoXSpatialNorm3D(nn.Module):
- r"""
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
- to 3D-video like data.
-
- CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
-
- Args:
- f_channels (`int`):
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
- zq_channels (`int`):
- The number of channels for the quantized vector as described in the paper.
- groups (`int`):
- Number of groups to separate the channels into for group normalization.
- """
-
- def __init__(
- self,
- f_channels: int,
- zq_channels: int,
- groups: int = 32,
- ):
- super().__init__()
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
- self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
- self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
-
- def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
- if f.shape[2] > 1 and f.shape[2] % 2 == 1:
- f_first, f_rest = f[:, :, :1], f[:, :, 1:]
- f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
- z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
- z_first = F.interpolate(z_first, size=f_first_size)
- z_rest = F.interpolate(z_rest, size=f_rest_size)
- zq = torch.cat([z_first, z_rest], dim=2)
- else:
- zq = F.interpolate(zq, size=f.shape[-3:])
-
- norm_f = self.norm_layer(f)
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
- return new_f
-
-
-class CogVideoXResnetBlock3D(nn.Module):
- r"""
- A 3D ResNet block used in the CogVideoX model.
-
- Args:
- in_channels (`int`):
- Number of input channels.
- out_channels (`int`, *optional*):
- Number of output channels. If None, defaults to `in_channels`.
- dropout (`float`, defaults to `0.0`):
- Dropout rate.
- temb_channels (`int`, defaults to `512`):
- Number of time embedding channels.
- groups (`int`, defaults to `32`):
- Number of groups to separate the channels into for group normalization.
- eps (`float`, defaults to `1e-6`):
- Epsilon value for normalization layers.
- non_linearity (`str`, defaults to `"swish"`):
- Activation function to use.
- conv_shortcut (bool, defaults to `False`):
- Whether or not to use a convolution shortcut.
- spatial_norm_dim (`int`, *optional*):
- The dimension to use for spatial norm if it is to be used instead of group norm.
- pad_mode (str, defaults to `"first"`):
- Padding mode.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: Optional[int] = None,
- dropout: float = 0.0,
- temb_channels: int = 512,
- groups: int = 32,
- eps: float = 1e-6,
- non_linearity: str = "swish",
- conv_shortcut: bool = False,
- spatial_norm_dim: Optional[int] = None,
- pad_mode: str = "first",
- ):
- super().__init__()
-
- out_channels = out_channels or in_channels
-
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.nonlinearity = get_activation(non_linearity)
- self.use_conv_shortcut = conv_shortcut
-
- if spatial_norm_dim is None:
- self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
- self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
- else:
- self.norm1 = CogVideoXSpatialNorm3D(
- f_channels=in_channels,
- zq_channels=spatial_norm_dim,
- groups=groups,
- )
- self.norm2 = CogVideoXSpatialNorm3D(
- f_channels=out_channels,
- zq_channels=spatial_norm_dim,
- groups=groups,
- )
-
- self.conv1 = CogVideoXCausalConv3d(
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
- )
-
- if temb_channels > 0:
- self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
-
- self.dropout = nn.Dropout(dropout)
- self.conv2 = CogVideoXCausalConv3d(
- in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
- )
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = CogVideoXCausalConv3d(
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
- )
- else:
- self.conv_shortcut = CogVideoXSafeConv3d(
- in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(
- self,
- inputs: torch.Tensor,
- temb: Optional[torch.Tensor] = None,
- zq: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- hidden_states = inputs
-
- if zq is not None:
- hidden_states = self.norm1(hidden_states, zq)
- else:
- hidden_states = self.norm1(hidden_states)
-
- hidden_states = self.nonlinearity(hidden_states)
- hidden_states = self.conv1(hidden_states)
-
- if temb is not None:
- hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
-
- if zq is not None:
- hidden_states = self.norm2(hidden_states, zq)
- else:
- hidden_states = self.norm2(hidden_states)
-
- hidden_states = self.nonlinearity(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.conv2(hidden_states)
-
- if self.in_channels != self.out_channels:
- inputs = self.conv_shortcut(inputs)
-
- hidden_states = hidden_states + inputs
- return hidden_states
-
-
-class CogVideoXDownBlock3D(nn.Module):
- r"""
- A downsampling block used in the CogVideoX model.
-
- Args:
- in_channels (`int`):
- Number of input channels.
- out_channels (`int`, *optional*):
- Number of output channels. If None, defaults to `in_channels`.
- temb_channels (`int`, defaults to `512`):
- Number of time embedding channels.
- num_layers (`int`, defaults to `1`):
- Number of resnet layers.
- dropout (`float`, defaults to `0.0`):
- Dropout rate.
- resnet_eps (`float`, defaults to `1e-6`):
- Epsilon value for normalization layers.
- resnet_act_fn (`str`, defaults to `"swish"`):
- Activation function to use.
- resnet_groups (`int`, defaults to `32`):
- Number of groups to separate the channels into for group normalization.
- add_downsample (`bool`, defaults to `True`):
-            Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
-        compress_time (`bool`, defaults to `False`):
-            Whether or not to downsample across the temporal dimension.
- pad_mode (str, defaults to `"first"`):
- Padding mode.
- """
-
- _supports_gradient_checkpointing = True
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- add_downsample: bool = True,
- downsample_padding: int = 0,
- compress_time: bool = False,
- pad_mode: str = "first",
- ):
- super().__init__()
-
- resnets = []
- for i in range(num_layers):
- in_channel = in_channels if i == 0 else out_channels
- resnets.append(
- CogVideoXResnetBlock3D(
- in_channels=in_channel,
- out_channels=out_channels,
- dropout=dropout,
- temb_channels=temb_channels,
- groups=resnet_groups,
- eps=resnet_eps,
- non_linearity=resnet_act_fn,
- pad_mode=pad_mode,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
- self.downsamplers = None
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- CogVideoXDownsample3D(
- out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
- )
- ]
- )
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- temb: Optional[torch.Tensor] = None,
- zq: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, zq
- )
- else:
- hidden_states = resnet(hidden_states, temb, zq)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- return hidden_states
-
-
-class CogVideoXMidBlock3D(nn.Module):
- r"""
- A middle block used in the CogVideoX model.
-
- Args:
- in_channels (`int`):
- Number of input channels.
- temb_channels (`int`, defaults to `512`):
- Number of time embedding channels.
- dropout (`float`, defaults to `0.0`):
- Dropout rate.
- num_layers (`int`, defaults to `1`):
- Number of resnet layers.
- resnet_eps (`float`, defaults to `1e-6`):
- Epsilon value for normalization layers.
- resnet_act_fn (`str`, defaults to `"swish"`):
- Activation function to use.
- resnet_groups (`int`, defaults to `32`):
- Number of groups to separate the channels into for group normalization.
- spatial_norm_dim (`int`, *optional*):
- The dimension to use for spatial norm if it is to be used instead of group norm.
- pad_mode (str, defaults to `"first"`):
- Padding mode.
- """
-
- _supports_gradient_checkpointing = True
-
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- spatial_norm_dim: Optional[int] = None,
- pad_mode: str = "first",
- ):
- super().__init__()
-
- resnets = []
- for _ in range(num_layers):
- resnets.append(
- CogVideoXResnetBlock3D(
- in_channels=in_channels,
- out_channels=in_channels,
- dropout=dropout,
- temb_channels=temb_channels,
- groups=resnet_groups,
- eps=resnet_eps,
- spatial_norm_dim=spatial_norm_dim,
- non_linearity=resnet_act_fn,
- pad_mode=pad_mode,
- )
- )
- self.resnets = nn.ModuleList(resnets)
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- temb: Optional[torch.Tensor] = None,
- zq: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, zq
- )
- else:
- hidden_states = resnet(hidden_states, temb, zq)
-
- return hidden_states
-
-
-class CogVideoXUpBlock3D(nn.Module):
- r"""
- An upsampling block used in the CogVideoX model.
-
- Args:
- in_channels (`int`):
- Number of input channels.
- out_channels (`int`, *optional*):
- Number of output channels. If None, defaults to `in_channels`.
- temb_channels (`int`, defaults to `512`):
- Number of time embedding channels.
- dropout (`float`, defaults to `0.0`):
- Dropout rate.
- num_layers (`int`, defaults to `1`):
- Number of resnet layers.
- resnet_eps (`float`, defaults to `1e-6`):
- Epsilon value for normalization layers.
- resnet_act_fn (`str`, defaults to `"swish"`):
- Activation function to use.
- resnet_groups (`int`, defaults to `32`):
- Number of groups to separate the channels into for group normalization.
- spatial_norm_dim (`int`, defaults to `16`):
- The dimension to use for spatial norm if it is to be used instead of group norm.
- add_upsample (`bool`, defaults to `True`):
-            Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
-        compress_time (`bool`, defaults to `False`):
-            Whether or not to upsample across the temporal dimension.
- pad_mode (str, defaults to `"first"`):
- Padding mode.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- spatial_norm_dim: int = 16,
- add_upsample: bool = True,
- upsample_padding: int = 1,
- compress_time: bool = False,
- pad_mode: str = "first",
- ):
- super().__init__()
-
- resnets = []
- for i in range(num_layers):
- in_channel = in_channels if i == 0 else out_channels
- resnets.append(
- CogVideoXResnetBlock3D(
- in_channels=in_channel,
- out_channels=out_channels,
- dropout=dropout,
- temb_channels=temb_channels,
- groups=resnet_groups,
- eps=resnet_eps,
- non_linearity=resnet_act_fn,
- spatial_norm_dim=spatial_norm_dim,
- pad_mode=pad_mode,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
- self.upsamplers = None
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [
- CogVideoXUpsample3D(
- out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
- )
- ]
- )
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- temb: Optional[torch.Tensor] = None,
- zq: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- r"""Forward method of the `CogVideoXUpBlock3D` class."""
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, zq
- )
- else:
- hidden_states = resnet(hidden_states, temb, zq)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class CogVideoXEncoder3D(nn.Module):
- r"""
- The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
- The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
- options.
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
- The number of output channels for each block.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- layers_per_block (`int`, *optional*, defaults to 2):
- The number of layers per block.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups for normalization.
- """
-
- _supports_gradient_checkpointing = True
-
- def __init__(
- self,
- in_channels: int = 3,
- out_channels: int = 16,
- down_block_types: Tuple[str, ...] = (
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- ),
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
- layers_per_block: int = 3,
- act_fn: str = "silu",
- norm_eps: float = 1e-6,
- norm_num_groups: int = 32,
- dropout: float = 0.0,
- pad_mode: str = "first",
- temporal_compression_ratio: float = 4,
- ):
- super().__init__()
-
- # log2 of temporal_compress_times
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
-
- self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
- self.down_blocks = nn.ModuleList([])
-
- # down blocks
- output_channel = block_out_channels[0]
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
- compress_time = i < temporal_compress_level
-
- if down_block_type == "CogVideoXDownBlock3D":
- down_block = CogVideoXDownBlock3D(
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=0,
- dropout=dropout,
- num_layers=layers_per_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- add_downsample=not is_final_block,
- compress_time=compress_time,
- )
- else:
- raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
-
- self.down_blocks.append(down_block)
-
- # mid block
- self.mid_block = CogVideoXMidBlock3D(
- in_channels=block_out_channels[-1],
- temb_channels=0,
- dropout=dropout,
- num_layers=2,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- pad_mode=pad_mode,
- )
-
- self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
- self.conv_act = nn.SiLU()
- self.conv_out = CogVideoXCausalConv3d(
- block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
- )
-
- self.gradient_checkpointing = False
-
- def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
- r"""The forward method of the `CogVideoXEncoder3D` class."""
- hidden_states = self.conv_in(sample)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- # 1. Down
- for down_block in self.down_blocks:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(down_block), hidden_states, temb, None
- )
-
- # 2. Mid
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), hidden_states, temb, None
- )
- else:
- # 1. Down
- for down_block in self.down_blocks:
- hidden_states = down_block(hidden_states, temb, None)
-
- # 2. Mid
- hidden_states = self.mid_block(hidden_states, temb, None)
-
- # 3. Post-process
- hidden_states = self.norm_out(hidden_states)
- hidden_states = self.conv_act(hidden_states)
- hidden_states = self.conv_out(hidden_states)
- return hidden_states
-
-
-class CogVideoXDecoder3D(nn.Module):
- r"""
- The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
- sample.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
- The number of output channels for each block.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- layers_per_block (`int`, *optional*, defaults to 2):
- The number of layers per block.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups for normalization.
- """
-
- _supports_gradient_checkpointing = True
-
- def __init__(
- self,
- in_channels: int = 16,
- out_channels: int = 3,
- up_block_types: Tuple[str, ...] = (
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- ),
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
- layers_per_block: int = 3,
- act_fn: str = "silu",
- norm_eps: float = 1e-6,
- norm_num_groups: int = 32,
- dropout: float = 0.0,
- pad_mode: str = "first",
- temporal_compression_ratio: float = 4,
- ):
- super().__init__()
-
- reversed_block_out_channels = list(reversed(block_out_channels))
-
- self.conv_in = CogVideoXCausalConv3d(
- in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
- )
-
- # mid block
- self.mid_block = CogVideoXMidBlock3D(
- in_channels=reversed_block_out_channels[0],
- temb_channels=0,
- num_layers=2,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- spatial_norm_dim=in_channels,
- pad_mode=pad_mode,
- )
-
- # up blocks
- self.up_blocks = nn.ModuleList([])
-
- output_channel = reversed_block_out_channels[0]
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
-
- for i, up_block_type in enumerate(up_block_types):
- prev_output_channel = output_channel
- output_channel = reversed_block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
- compress_time = i < temporal_compress_level
-
- if up_block_type == "CogVideoXUpBlock3D":
- up_block = CogVideoXUpBlock3D(
- in_channels=prev_output_channel,
- out_channels=output_channel,
- temb_channels=0,
- dropout=dropout,
- num_layers=layers_per_block + 1,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- spatial_norm_dim=in_channels,
- add_upsample=not is_final_block,
- compress_time=compress_time,
- pad_mode=pad_mode,
- )
- prev_output_channel = output_channel
- else:
- raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
-
- self.up_blocks.append(up_block)
-
- self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
- self.conv_act = nn.SiLU()
- self.conv_out = CogVideoXCausalConv3d(
- reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
- )
-
- self.gradient_checkpointing = False
-
- def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
- r"""The forward method of the `CogVideoXDecoder3D` class."""
- hidden_states = self.conv_in(sample)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- # 1. Mid
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), hidden_states, temb, sample
- )
-
- # 2. Up
- for up_block in self.up_blocks:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block), hidden_states, temb, sample
- )
- else:
- # 1. Mid
- hidden_states = self.mid_block(hidden_states, temb, sample)
-
- # 2. Up
- for up_block in self.up_blocks:
- hidden_states = up_block(hidden_states, temb, sample)
-
- # 3. Post-process
- hidden_states = self.norm_out(hidden_states, sample)
- hidden_states = self.conv_act(hidden_states)
- hidden_states = self.conv_out(hidden_states)
- return hidden_states
-
-
-class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
- r"""
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
- [CogVideoX](https://github.com/THUDM/CogVideo).
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
- for all models (such as downloading or saving).
-
- Parameters:
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
- Tuple of downsample block types.
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
- Tuple of upsample block types.
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
- Tuple of block output channels.
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
- scaling_factor (`float`, *optional*, defaults to `1.15258426`):
- The component-wise standard deviation of the trained latent space computed using the first batch of the
- training set. This is used to scale the latent space to have unit variance when training the diffusion
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
-        force_upcast (`bool`, *optional*, defaults to `True`):
-            If enabled, it will force the VAE to run in float32 for high-resolution image pipelines, such as SD-XL. The VAE
-            can be fine-tuned / trained to a lower range without losing too much precision, in which case
-            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
- """
-
- _supports_gradient_checkpointing = True
- _no_split_modules = ["CogVideoXResnetBlock3D"]
-
- @register_to_config
- def __init__(
- self,
- in_channels: int = 3,
- out_channels: int = 3,
- down_block_types: Tuple[str] = (
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- "CogVideoXDownBlock3D",
- ),
- up_block_types: Tuple[str] = (
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- "CogVideoXUpBlock3D",
- ),
- block_out_channels: Tuple[int] = (128, 256, 256, 512),
- latent_channels: int = 16,
- layers_per_block: int = 3,
- act_fn: str = "silu",
- norm_eps: float = 1e-6,
- norm_num_groups: int = 32,
- temporal_compression_ratio: float = 4,
- sample_height: int = 480,
- sample_width: int = 720,
- scaling_factor: float = 1.15258426,
- shift_factor: Optional[float] = None,
- latents_mean: Optional[Tuple[float]] = None,
- latents_std: Optional[Tuple[float]] = None,
-        force_upcast: bool = True,
- use_quant_conv: bool = False,
- use_post_quant_conv: bool = False,
- ):
- super().__init__()
-
- self.encoder = CogVideoXEncoder3D(
- in_channels=in_channels,
- out_channels=latent_channels,
- down_block_types=down_block_types,
- block_out_channels=block_out_channels,
- layers_per_block=layers_per_block,
- act_fn=act_fn,
- norm_eps=norm_eps,
- norm_num_groups=norm_num_groups,
- temporal_compression_ratio=temporal_compression_ratio,
- )
- self.decoder = CogVideoXDecoder3D(
- in_channels=latent_channels,
- out_channels=out_channels,
- up_block_types=up_block_types,
- block_out_channels=block_out_channels,
- layers_per_block=layers_per_block,
- act_fn=act_fn,
- norm_eps=norm_eps,
- norm_num_groups=norm_num_groups,
- temporal_compression_ratio=temporal_compression_ratio,
- )
- self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
- self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
-
- self.use_slicing = False
- self.use_tiling = False
-
-        # Can be increased to decode more latent frames at once, but this comes at a reasonable memory cost and is not
-        # recommended because the temporal parts of the VAE here are tricky to understand.
- # If you decode X latent frames together, the number of output frames is:
- # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
- #
- # Example with num_latent_frames_batch_size = 2:
- # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
- # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
- # => 6 * 8 = 48 frames
- # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
- # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
- # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
- # => 1 * 9 + 5 * 8 = 49 frames
-        # It has been implemented this way so as not to have "magic values" in the code base that would be hard to explain. Note that
-        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive to a different
-        # number of temporal frames.
- self.num_latent_frames_batch_size = 2
-
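# --- Editor's note: the sketch below is not part of the deleted file. It only re-checks the
# frame-count arithmetic from the comment above: each decoded slice of X latent frames yields
# X + 6 output frames (+2 conv cache, +2 time upscale_1, +4 time upscale_2, -2 causal conv
# downscale), and any remainder is folded into the first slice.
def expected_output_frames(num_latent_frames, batch=2):
    per_slice_extra = 2 + 2 + 4 - 2
    num_slices = num_latent_frames // batch
    first_slice = batch + num_latent_frames % batch           # remainder handled by the first slice
    return (first_slice + per_slice_extra) + (num_slices - 1) * (batch + per_slice_extra)

assert expected_output_frames(12) == 48   # matches the 12-latent-frame example above
assert expected_output_frames(13) == 49   # matches the 13-latent-frame example above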
- # We make the minimum height and width of sample for tiling half that of the generally supported
- self.tile_sample_min_height = sample_height // 2
- self.tile_sample_min_width = sample_width // 2
- self.tile_latent_min_height = int(
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
- )
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
-
- # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
- # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
- # and so the tiling implementation has only been tested on those specific resolutions.
- self.tile_overlap_factor_height = 1 / 6
- self.tile_overlap_factor_width = 1 / 5
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
- module.gradient_checkpointing = value
-
- def _clear_fake_context_parallel_cache(self):
- for name, module in self.named_modules():
- if isinstance(module, CogVideoXCausalConv3d):
- logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
- module._clear_fake_context_parallel_cache()
-
- def enable_tiling(
- self,
- tile_sample_min_height: Optional[int] = None,
- tile_sample_min_width: Optional[int] = None,
- tile_overlap_factor_height: Optional[float] = None,
- tile_overlap_factor_width: Optional[float] = None,
- ) -> None:
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
-        processing larger images.
-
- Args:
- tile_sample_min_height (`int`, *optional*):
- The minimum height required for a sample to be separated into tiles across the height dimension.
- tile_sample_min_width (`int`, *optional*):
- The minimum width required for a sample to be separated into tiles across the width dimension.
-            tile_overlap_factor_height (`float`, *optional*):
- The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
- no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
- value might cause more tiles to be processed leading to slow down of the decoding process.
-            tile_overlap_factor_width (`float`, *optional*):
- The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
- are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
- value might cause more tiles to be processed leading to slow down of the decoding process.
- """
- self.use_tiling = True
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
- self.tile_latent_min_height = int(
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
- )
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
- self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
- self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
-
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
- @apply_forward_hook
- def encode(
- self, x: torch.Tensor, return_dict: bool = True
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
- """
- Encode a batch of images into latents.
-
- Args:
- x (`torch.Tensor`): Input batch of images.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
-
- Returns:
- The latent representations of the encoded images. If `return_dict` is True, a
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
- """
- batch_size, num_channels, num_frames, height, width = x.shape
- if num_frames == 1:
- h = self.encoder(x)
- if self.quant_conv is not None:
- h = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(h)
- else:
- frame_batch_size = 4
- h = []
- for i in range(num_frames // frame_batch_size):
- remaining_frames = num_frames % frame_batch_size
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
- end_frame = frame_batch_size * (i + 1) + remaining_frames
- z_intermediate = x[:, :, start_frame:end_frame]
- z_intermediate = self.encoder(z_intermediate)
- if self.quant_conv is not None:
- z_intermediate = self.quant_conv(z_intermediate)
- h.append(z_intermediate)
- self._clear_fake_context_parallel_cache()
- h = torch.cat(h, dim=2)
- posterior = DiagonalGaussianDistribution(h)
- if not return_dict:
- return (posterior,)
- return AutoencoderKLOutput(latent_dist=posterior)
-
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
- batch_size, num_channels, num_frames, height, width = z.shape
-
- if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
- return self.tiled_decode(z, return_dict=return_dict)
-
- if num_frames == 1:
- dec = []
- z_intermediate = z
- if self.post_quant_conv is not None:
- z_intermediate = self.post_quant_conv(z_intermediate)
- z_intermediate = self.decoder(z_intermediate)
- dec.append(z_intermediate)
- else:
- frame_batch_size = self.num_latent_frames_batch_size
- dec = []
- for i in range(num_frames // frame_batch_size):
- remaining_frames = num_frames % frame_batch_size
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
- end_frame = frame_batch_size * (i + 1) + remaining_frames
- z_intermediate = z[:, :, start_frame:end_frame]
- if self.post_quant_conv is not None:
- z_intermediate = self.post_quant_conv(z_intermediate)
- z_intermediate = self.decoder(z_intermediate)
- dec.append(z_intermediate)
-
- self._clear_fake_context_parallel_cache()
- dec = torch.cat(dec, dim=2)
-
- if not return_dict:
- return (dec,)
-
- return DecoderOutput(sample=dec)
-
- @apply_forward_hook
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
- """
- Decode a batch of images.
-
- Args:
- z (`torch.Tensor`): Input batch of latent vectors.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.vae.DecoderOutput`] or `tuple`:
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
- returned.
- """
- if self.use_slicing and z.shape[0] > 1:
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
- decoded = torch.cat(decoded_slices)
- else:
- decoded = self._decode(z).sample
-
- if not return_dict:
- return (decoded,)
- return DecoderOutput(sample=decoded)
-
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
- for y in range(blend_extent):
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
- y / blend_extent
- )
- return b
-
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
- blend_extent = min(a.shape[4], b.shape[4], blend_extent)
- for x in range(blend_extent):
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
- x / blend_extent
- )
- return b
-
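# --- Editor's note: the sketch below is not part of the deleted file. It shows, on a tiny
# tensor, the linear cross-fade that blend_h (and, transposed, blend_v) applies over the
# overlapping columns of two neighbouring tiles.
import torch

a = torch.ones(1, 1, 1, 1, 4)    # right edge of the "left" tile, all ones
b = torch.zeros(1, 1, 1, 1, 4)   # left edge of the "right" tile, all zeros
blend_extent = 4
for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
print(b.flatten())               # tensor([1.0000, 0.7500, 0.5000, 0.2500]): a smooth left-to-right ramp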
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
- r"""
- Decode a batch of images using a tiled decoder.
-
- Args:
- z (`torch.Tensor`): Input batch of latent vectors.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.vae.DecoderOutput`] or `tuple`:
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
- returned.
- """
- # Rough memory assessment:
- # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
- # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
- # - Assume fp16 (2 bytes per value).
- # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
- #
- # Memory assessment when using tiling:
- # - Assume everything as above but now HxW is 240x360 by tiling in half
- # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
-
- batch_size, num_channels, num_frames, height, width = z.shape
-
- overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
- overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
- blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
- blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
- row_limit_height = self.tile_sample_min_height - blend_extent_height
- row_limit_width = self.tile_sample_min_width - blend_extent_width
- frame_batch_size = self.num_latent_frames_batch_size
-
- # Split z into overlapping tiles and decode them separately.
- # The tiles have an overlap to avoid seams between tiles.
- rows = []
- for i in range(0, height, overlap_height):
- row = []
- for j in range(0, width, overlap_width):
- time = []
- for k in range(num_frames // frame_batch_size):
- remaining_frames = num_frames % frame_batch_size
- start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
- end_frame = frame_batch_size * (k + 1) + remaining_frames
- tile = z[
- :,
- :,
- start_frame:end_frame,
- i : i + self.tile_latent_min_height,
- j : j + self.tile_latent_min_width,
- ]
- if self.post_quant_conv is not None:
- tile = self.post_quant_conv(tile)
- tile = self.decoder(tile)
- time.append(tile)
- self._clear_fake_context_parallel_cache()
- row.append(torch.cat(time, dim=2))
- rows.append(row)
-
- result_rows = []
- for i, row in enumerate(rows):
- result_row = []
- for j, tile in enumerate(row):
- # blend the above tile and the left tile
- # to the current tile and add the current tile to the result row
- if i > 0:
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
- if j > 0:
- tile = self.blend_h(row[j - 1], tile, blend_extent_width)
- result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
- result_rows.append(torch.cat(result_row, dim=4))
-
- dec = torch.cat(result_rows, dim=3)
-
- if not return_dict:
- return (dec,)
-
- return DecoderOutput(sample=dec)
-
- def forward(
- self,
- sample: torch.Tensor,
- sample_posterior: bool = False,
- return_dict: bool = True,
- generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
- x = sample
- posterior = self.encode(x).latent_dist
- if sample_posterior:
- z = posterior.sample(generator=generator)
- else:
- z = posterior.mode()
- dec = self.decode(z)
- if not return_dict:
- return (dec,)
- return dec
diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
deleted file mode 100644
index f598147..0000000
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ /dev/null
@@ -1,866 +0,0 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import math
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
-from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from diffusers.utils import BaseOutput, logging, replace_example_docstring
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
-from diffusers.image_processor import VaeImageProcessor
-from einops import rearrange
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-EXAMPLE_DOC_STRING = """
- Examples:
- ```python
- >>> import torch
- >>> from diffusers import CogVideoX_Fun_Pipeline
- >>> from diffusers.utils import export_to_video
-
- >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
- >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
- >>> prompt = (
- ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
- ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
- ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
- ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
- ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
- ... "atmosphere of this unique musical performance."
- ... )
- >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
- >>> export_to_video(video, "output.mp4", fps=8)
- ```
-"""
-
-
-# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
-def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
- tw = tgt_width
- th = tgt_height
- h, w = src
- r = h / w
- if r > (th / tw):
- resize_height = th
- resize_width = int(round(th / h * w))
- else:
- resize_width = tw
- resize_height = int(round(tw / w * h))
-
- crop_top = int(round((th - resize_height) / 2.0))
- crop_left = int(round((tw - resize_width) / 2.0))
-
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
-
-
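# --- Editor's note: the sketch below is not part of the deleted file. It is a worked example of
# get_resize_crop_region_for_grid above (the grid sizes are illustrative): a 48x90 source grid is
# scaled to fit a 90x60 target and centered, so only the top and bottom are cropped.
top_left, bottom_right = get_resize_crop_region_for_grid((48, 90), tgt_width=90, tgt_height=60)
print(top_left, bottom_right)   # (6, 0) (54, 90)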
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
- must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
- `num_inference_steps` and `sigmas` must be `None`.
- sigmas (`List[float]`, *optional*):
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- elif sigmas is not None:
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accept_sigmas:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" sigmas schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
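# --- Editor's note: the sketch below is not part of the deleted file. It is a minimal usage
# example of retrieve_timesteps above; the scheduler class and checkpoint id are assumptions
# based on the imports and docstring in this file, not values taken from this repository.
from diffusers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler.from_pretrained("THUDM/CogVideoX-2b", subfolder="scheduler")
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
print(len(timesteps), num_inference_steps)   # 50 50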
-@dataclass
-class CogVideoX_Fun_PipelineOutput(BaseOutput):
- r"""
- Output class for CogVideo pipelines.
-
- Args:
-        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
-            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
-            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
-            `(batch_size, num_frames, channels, height, width)`.
- """
-
- videos: torch.Tensor
-
-
-class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
- r"""
- Pipeline for text-to-video generation using CogVideoX.
-
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
- Args:
-        vae ([`AutoencoderKLCogVideoX`]):
-            Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
- transformer ([`CogVideoXTransformer3DModel`]):
- A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
- scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
- """
-
- _optional_components = []
- model_cpu_offload_seq = "vae->transformer->vae"
-
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- ]
-
- def __init__(
- self,
- vae: AutoencoderKLCogVideoX,
- transformer: CogVideoXTransformer3DModel,
- scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
- ):
- super().__init__()
-
- self.register_modules(
- vae=vae, transformer=transformer, scheduler=scheduler
- )
- self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
-
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
-
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
- )
-
- def prepare_latents(
- self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
- latents=None, freenoise=True, context_size=None, context_overlap=None
- ):
- shape = (
- batch_size,
- (num_frames - 1) // self.vae_scale_factor_temporal + 1,
- num_channels_latents,
- height // self.vae_scale_factor_spatial,
- width // self.vae_scale_factor_spatial,
- )
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
- noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
- if freenoise:
- print("Applying FreeNoise")
- # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
- video_length = num_frames // 4
- delta = context_size - context_overlap
- for start_idx in range(0, video_length-context_size, delta):
- # start_idx corresponds to the beginning of a context window
- # goal: place shuffled in the delta region right after the end of the context window
- # if space after context window is not enough to place the noise, adjust and finish
- place_idx = start_idx + context_size
- # if place_idx is outside the valid indexes, we are already finished
- if place_idx >= video_length:
- break
- end_idx = place_idx - 1
- #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
-
- # if there is not enough room to copy delta amount of indexes, copy limited amount and finish
- if end_idx + delta >= video_length:
- final_delta = video_length - place_idx
- # generate list of indexes in final delta region
- list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
- # shuffle list
- list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
- # apply shuffled indexes
- noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
- break
- # otherwise, do normal behavior
- # generate list of indexes in delta region
- list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
- # shuffle list
- list_idx = list_idx[torch.randperm(delta, generator=generator)]
- # apply shuffled indexes
- #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
- noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
- if latents is None:
- latents = noise.to(device)
- else:
- latents = latents.to(device)
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
- latent_timestep = timesteps[:1]
-
- noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
- frames_needed = noise.shape[1]
- current_frames = latents.shape[1]
-
- if frames_needed > current_frames:
- repeat_factor = frames_needed // current_frames
- additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
- latents = torch.cat((latents, additional_frame), dim=1)
- elif frames_needed < current_frames:
- latents = latents[:, :frames_needed, :, :, :]
-
- latents = self.scheduler.add_noise(latents, noise, latent_timestep)
- latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
- return latents, timesteps, noise
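-
- # Worked sketch of the FreeNoise shuffle above, assuming hypothetical values
- # num_frames=97 (so the loop sees video_length = 97 // 4 = 24 latent frames),
- # context_size=12 and context_overlap=4, i.e. delta = 8:
- #   start_idx=0 -> place_idx=12, end_idx=11; 11 + 8 < 24, so a shuffled copy of
- #                  noise frames 0..7 is written into frames 12..19.
- #   start_idx=8 -> place_idx=20, end_idx=19; 19 + 8 >= 24, so final_delta = 24 - 20 = 4
- #                  and a shuffled copy of frames 8..11 fills frames 20..23, then the loop stops.
- # Every frame outside the first context window therefore reuses (shuffled) noise from an
- # earlier window, which keeps overlapping windows temporally coherent.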
-
- def prepare_control_latents(
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
- ):
- # resize the mask to latents shape as we concatenate the mask to the latents
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
- # and half precision
-
- if mask is not None:
- mask = mask.to(device=device, dtype=self.vae.dtype)
- bs = 1
- new_mask = []
- for i in range(0, mask.shape[0], bs):
- mask_bs = mask[i : i + bs]
- mask_bs = self.vae.encode(mask_bs)[0]
- mask_bs = mask_bs.mode()
- new_mask.append(mask_bs)
- mask = torch.cat(new_mask, dim = 0)
- mask = mask * self.vae.config.scaling_factor
-
- if masked_image is not None:
- masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
- bs = 1
- new_mask_pixel_values = []
- for i in range(0, masked_image.shape[0], bs):
- mask_pixel_values_bs = masked_image[i : i + bs]
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
- mask_pixel_values_bs = mask_pixel_values_bs.mode()
- new_mask_pixel_values.append(mask_pixel_values_bs)
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
- masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
- else:
- masked_image_latents = None
-
- return mask, masked_image_latents
-
- def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
- latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
- latents = 1 / self.vae.config.scaling_factor * latents
-
- frames = self.vae.decode(latents).sample
- frames = (frames / 2 + 0.5).clamp(0, 1)
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
- frames = frames.cpu().float().numpy()
- return frames
-
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
- def prepare_extra_step_kwargs(self, generator, eta):
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
-
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
- return extra_step_kwargs
-
- # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
- def check_inputs(
- self,
- prompt,
- height,
- width,
- negative_prompt,
- callback_on_step_end_tensor_inputs,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- ):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
- if prompt is not None and prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
- " only forward one of the two."
- )
- elif prompt is None and prompt_embeds is None:
- raise ValueError(
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
- )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if prompt is not None and negative_prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
- )
-
- if negative_prompt is not None and negative_prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
- )
-
- if prompt_embeds is not None and negative_prompt_embeds is not None:
- if prompt_embeds.shape != negative_prompt_embeds.shape:
- raise ValueError(
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
- f" {negative_prompt_embeds.shape}."
- )
-
- def fuse_qkv_projections(self) -> None:
- r"""Enables fused QKV projections."""
- self.fusing_transformer = True
- self.transformer.fuse_qkv_projections()
-
- def unfuse_qkv_projections(self) -> None:
- r"""Disable QKV projection fusion if enabled."""
- if not self.fusing_transformer:
- logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
- else:
- self.transformer.unfuse_qkv_projections()
- self.fusing_transformer = False
-
- def _prepare_rotary_positional_embeddings(
- self,
- height: int,
- width: int,
- num_frames: int,
- device: torch.device,
- start_frame: Optional[int] = None,
- end_frame: Optional[int] = None,
- context_frames: Optional[int] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- use_real=True,
- )
-
- if start_frame is not None or context_frames is not None:
- freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
- freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
- if context_frames is not None:
- freqs_cos = freqs_cos[context_frames]
- freqs_sin = freqs_sin[context_frames]
- else:
- freqs_cos = freqs_cos[start_frame:end_frame]
- freqs_sin = freqs_sin[start_frame:end_frame]
-
- freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
- freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
-
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
- return freqs_cos, freqs_sin
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
- @property
- def interrupt(self):
- return self._interrupt
-
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
- def get_timesteps(self, num_inference_steps, strength, device):
- # get the original timestep using init_timestep
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
-
- t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
-
- return timesteps, num_inference_steps - t_start
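-
- # Worked sketch for get_timesteps, assuming hypothetical values num_inference_steps=50,
- # strength=0.6 and a first-order scheduler: init_timestep = min(int(50 * 0.6), 50) = 30,
- # t_start = 50 - 30 = 20, so the last 30 entries of scheduler.timesteps are returned and
- # 30 denoising steps are actually run. With strength=1.0 the full schedule is kept.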
-
- @torch.no_grad()
- @replace_example_docstring(EXAMPLE_DOC_STRING)
- def __call__(
- self,
- prompt: Optional[Union[str, List[str]]] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- height: int = 480,
- width: int = 720,
- video: Optional[torch.FloatTensor] = None,
- control_video: Optional[torch.FloatTensor] = None,
- num_frames: int = 49,
- num_inference_steps: int = 50,
- timesteps: Optional[List[int]] = None,
- guidance_scale: float = 6,
- use_dynamic_cfg: bool = False,
- denoise_strength: float = 1.0,
- num_videos_per_prompt: int = 1,
- eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: str = "numpy",
- return_dict: bool = False,
- callback_on_step_end: Optional[
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
- ] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 226,
- comfyui_progressbar: bool = False,
- control_strength: float = 1.0,
- control_start_percent: float = 0.0,
- control_end_percent: float = 1.0,
- scheduler_name: str = "DPM",
- context_schedule: Optional[str] = None,
- context_frames: Optional[int] = None,
- context_stride: Optional[int] = None,
- context_overlap: Optional[int] = None,
- freenoise: Optional[bool] = True,
- tora: Optional[dict] = None,
- ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
- """
- Function invoked when calling the pipeline for generation.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
- instead.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
- less than `1`).
- height (`int`, *optional*, defaults to `480`):
- The height in pixels of the generated video.
- width (`int`, *optional*, defaults to `720`):
- The width in pixels of the generated video.
- num_frames (`int`, defaults to `49`):
- Number of frames to generate. Should be one more than a multiple of self.vae_scale_factor_temporal
- (e.g. 49), because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
- num_seconds is 6 and fps is 8. Since videos can be saved at any fps, the only condition that
- needs to be satisfied is the one above.
- num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to `6`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. A higher guidance scale encourages the model to generate videos that are closely linked to the text
- `prompt`, usually at the expense of lower video quality.
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
- The number of videos to generate per prompt.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
- to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will be generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
- argument.
- output_type (`str`, *optional*, defaults to `"numpy"`):
- The output format of the generated video. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
- max_sequence_length (`int`, defaults to `226`):
- Maximum sequence length in encoded prompt. Must be consistent with
- `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
-
- Examples:
-
- Returns:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
- `tuple`. When returning a tuple, the first element is a list with the generated images.
- """
-
- # if num_frames > 49:
- # raise ValueError(
- # "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
- # )
-
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
- height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
- width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
- num_videos_per_prompt = 1
-
- # 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- height,
- width,
- negative_prompt,
- callback_on_step_end_tensor_inputs,
- prompt_embeds,
- negative_prompt_embeds,
- )
- self._guidance_scale = guidance_scale
- self._interrupt = False
-
- # 2. Default call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self._execution_device
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
- if do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-
- # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
- self._num_timesteps = len(timesteps)
- if comfyui_progressbar:
- from comfy.utils import ProgressBar
- pbar = ProgressBar(num_inference_steps + 2)
-
- # 5. Prepare latents.
- latent_channels = self.vae.config.latent_channels
- latents, timesteps, noise = self.prepare_latents(
- batch_size * num_videos_per_prompt,
- latent_channels,
- num_frames,
- height,
- width,
- self.vae.dtype,
- device,
- generator,
- timesteps,
- denoise_strength,
- num_inference_steps,
- latents,
- context_size=context_frames,
- context_overlap=context_overlap,
- freenoise=freenoise,
- )
- if comfyui_progressbar:
- pbar.update(1)
-
-
- control_video_latents_input = (
- torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video
- )
- control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
-
- control_latents = control_latents * control_strength
-
- if comfyui_progressbar:
- pbar.update(1)
-
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-
-
- # 8. Denoising loop
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
- if context_schedule is not None:
- print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
- use_context_schedule = True
- from .context import get_context_scheduler
- context = get_context_scheduler(context_schedule)
-
- else:
- use_context_schedule = False
- print(" context schedule disabled")
- # 7. Create rotary embeds if required
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
- if tora is not None and do_classifier_free_guidance:
- video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
-
- if tora is not None:
- for module in self.transformer.fuser_list:
- for param in module.parameters():
- param.data = param.data.to(device)
-
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- # for DPM-solver++
- old_pred_original_sample = None
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
- if use_context_schedule:
-
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # Calculate the current step percentage
- current_step_percentage = i / num_inference_steps
-
- # Determine if control_latents should be applied
- apply_control = control_start_percent <= current_step_percentage <= control_end_percent
- current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latent_model_input.shape[0])
-
- context_queue = list(context(
- i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
- ))
- counter = torch.zeros_like(latent_model_input)
- noise_pred = torch.zeros_like(latent_model_input)
-
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-
- for c in context_queue:
- partial_latent_model_input = latent_model_input[:, c, :, :, :]
- partial_control_latents = current_control_latents[:, c, :, :, :]
-
- # predict noise model_output
- noise_pred[:, c, :, :, :] += self.transformer(
- hidden_states=partial_latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- control_latents=partial_control_latents,
- )[0]
-
- counter[:, c, :, :, :] += 1
- noise_pred = noise_pred.float()
-
- noise_pred /= counter
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- noise_pred,
- old_pred_original_sample,
- t,
- timesteps[i - 1] if i > 0 else None,
- latents,
- **extra_step_kwargs,
- return_dict=False,
- )
- latents = latents.to(prompt_embeds.dtype)
-
- # call the callback, if provided
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- if comfyui_progressbar:
- pbar.update(1)
- else:
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # Calculate the current step percentage
- current_step_percentage = i / num_inference_steps
-
- # Determine if control_latents should be applied
- apply_control = control_start_percent <= current_step_percentage <= control_end_percent
- current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latent_model_input.shape[0])
-
- # predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- control_latents=current_control_latents,
- video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
-
- )[0]
- noise_pred = noise_pred.float()
-
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- noise_pred,
- old_pred_original_sample,
- t,
- timesteps[i - 1] if i > 0 else None,
- latents,
- **extra_step_kwargs,
- return_dict=False,
- )
- latents = latents.to(prompt_embeds.dtype)
-
- # call the callback, if provided
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- if comfyui_progressbar:
- pbar.update(1)
-
- # if output_type == "numpy":
- # video = self.decode_latents(latents)
- # elif not output_type == "latent":
- # video = self.decode_latents(latents)
- # video = self.video_processor.postprocess_video(video=video, output_type=output_type)
- # else:
- # video = latents
-
- # Offload all models
- self.maybe_free_model_hooks()
-
- # if not return_dict:
- # video = torch.from_numpy(video)
-
- return latents
\ No newline at end of file
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
deleted file mode 100644
index a6f0e9e..0000000
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ /dev/null
@@ -1,1037 +0,0 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import math
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
-from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from diffusers.utils import BaseOutput, logging, replace_example_docstring
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
-from diffusers.image_processor import VaeImageProcessor
-from einops import rearrange
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-EXAMPLE_DOC_STRING = """
- Examples:
- ```python
- >>> import torch
- >>> from diffusers import CogVideoX_Fun_Pipeline
- >>> from diffusers.utils import export_to_video
-
- >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
- >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
- >>> prompt = (
- ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
- ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
- ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
- ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
- ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
- ... "atmosphere of this unique musical performance."
- ... )
- >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
- >>> export_to_video(video, "output.mp4", fps=8)
- ```
-"""
-
-
-# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
-def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
- tw = tgt_width
- th = tgt_height
- h, w = src
- r = h / w
- if r > (th / tw):
- resize_height = th
- resize_width = int(round(th / h * w))
- else:
- resize_width = tw
- resize_height = int(round(tw / w * h))
-
- crop_top = int(round((th - resize_height) / 2.0))
- crop_left = int(round((tw - resize_width) / 2.0))
-
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
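-
- # Worked sketch, assuming the typical CogVideoX sizes: a 480x720 input with
- # vae_scale_factor_spatial=8 and patch_size=2 gives a (30, 45) grid, which already matches
- # (base_size_height, base_size_width) = (30, 45), so the function returns ((0, 0), (30, 45))
- # and nothing is cropped. For a square 720x720 input the grid is (45, 45): r = 1.0 > 30/45,
- # so it resizes to (30, 30) and centers the crop horizontally at
- # crop_left = round((45 - 30) / 2) = 8, returning ((0, 8), (30, 38)).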
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
- must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
- `num_inference_steps` and `sigmas` must be `None`.
- sigmas (`List[float]`, *optional*):
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- elif sigmas is not None:
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accept_sigmas:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" sigmas schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
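-
- # Usage sketch for retrieve_timesteps (hypothetical values):
- #   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda")
- # lets the scheduler pick the spacing, while
- #   timesteps, n = retrieve_timesteps(scheduler, device="cuda", timesteps=[999, 749, 499, 249])
- # overrides it (only for schedulers whose set_timesteps accepts `timesteps`). Passing both
- # `timesteps` and `sigmas` raises a ValueError.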
-
-
-def resize_mask(mask, latent, process_first_frame_only=True):
- latent_size = latent.size()
- batch_size, channels, num_frames, height, width = mask.shape
-
- if process_first_frame_only:
- target_size = list(latent_size[2:])
- target_size[0] = 1
- first_frame_resized = F.interpolate(
- mask[:, :, 0:1, :, :],
- size=target_size,
- mode='trilinear',
- align_corners=False
- )
-
- target_size = list(latent_size[2:])
- target_size[0] = target_size[0] - 1
- if target_size[0] != 0:
- remaining_frames_resized = F.interpolate(
- mask[:, :, 1:, :, :],
- size=target_size,
- mode='trilinear',
- align_corners=False
- )
- resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
- else:
- resized_mask = first_frame_resized
- else:
- target_size = list(latent_size[2:])
- resized_mask = F.interpolate(
- mask,
- size=target_size,
- mode='trilinear',
- align_corners=False
- )
- return resized_mask
-
-def add_noise_to_reference_video(image, ratio=None):
- if ratio is None:
- sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
- sigma = torch.exp(sigma).to(image.dtype)
- else:
- sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
-
- image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
- image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
- image = image + image_noise
- return image
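-
- # Worked sketch for add_noise_to_reference_video: with ratio=0.05 every sample gets
- # sigma=0.05, so the reference video is perturbed with N(0, 0.05^2) noise, except where the
- # pixel value equals -1 (the masked/padded value), which stays untouched. With ratio=None,
- # sigma is drawn per sample from exp(N(-3.0, 0.5)), i.e. roughly 0.03-0.08 most of the time.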
-
-@dataclass
-class CogVideoX_Fun_PipelineOutput(BaseOutput):
- r"""
- Output class for CogVideo pipelines.
-
- Args:
- video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
- denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
- `(batch_size, num_frames, channels, height, width)`.
- """
-
- videos: torch.Tensor
-
-
-class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
- r"""
- Pipeline for text-to-video generation using CogVideoX.
-
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
- Args:
- vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
- transformer ([`CogVideoXTransformer3DModel`]):
- A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
- scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
- """
-
- _optional_components = []
- model_cpu_offload_seq = "vae->transformer->vae"
-
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- ]
-
- def __init__(
- self,
- vae: AutoencoderKLCogVideoX,
- transformer: CogVideoXTransformer3DModel,
- scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
- ):
- super().__init__()
-
- self.register_modules(
- vae=vae, transformer=transformer, scheduler=scheduler
- )
- self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
-
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
-
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
- )
-
- def prepare_latents(
- self,
- batch_size,
- num_channels_latents,
- height,
- width,
- video_length,
- dtype,
- device,
- generator,
- latents=None,
- video=None,
- timestep=None,
- is_strength_max=True,
- return_noise=False,
- return_video_latents=False,
- context_size=None,
- context_overlap=None,
- freenoise=False,
- ):
- shape = (
- batch_size,
- (video_length - 1) // self.vae_scale_factor_temporal + 1,
- num_channels_latents,
- height // self.vae_scale_factor_spatial,
- width // self.vae_scale_factor_spatial,
- )
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
- if return_video_latents or (latents is None and not is_strength_max):
- video = video.to(device=device, dtype=self.vae.dtype)
-
- bs = 1
- new_video = []
- for i in range(0, video.shape[0], bs):
- video_bs = video[i : i + bs]
- video_bs = self.vae.encode(video_bs)[0]
- video_bs = video_bs.sample()
- new_video.append(video_bs)
- video = torch.cat(new_video, dim = 0)
- video = video * self.vae.config.scaling_factor
-
- video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
- video_latents = video_latents.to(device=device, dtype=dtype)
- video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
-
- if latents is None:
- noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=dtype)
- if freenoise:
- print("Applying FreeNoise")
- # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
- video_length = video_length // 4
- delta = context_size - context_overlap
- for start_idx in range(0, video_length-context_size, delta):
- # start_idx corresponds to the beginning of a context window
- # goal: place shuffled noise in the delta region right after the end of the context window
- # if space after context window is not enough to place the noise, adjust and finish
- place_idx = start_idx + context_size
- # if place_idx is outside the valid indexes, we are already finished
- if place_idx >= video_length:
- break
- end_idx = place_idx - 1
- #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
-
- # if there is not enough room to copy delta amount of indexes, copy limited amount and finish
- if end_idx + delta >= video_length:
- final_delta = video_length - place_idx
- # generate list of indexes in final delta region
- list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
- # shuffle list
- list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
- # apply shuffled indexes
- noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
- break
- # otherwise, do normal behavior
- # generate list of indexes in delta region
- list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
- # shuffle list
- list_idx = list_idx[torch.randperm(delta, generator=generator)]
- # apply shuffled indexes
- #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
- noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
-
- # if strength is 1.0 then initialise the latents to noise, else initialise to image + noise
- latents = noise if is_strength_max else self.scheduler.add_noise(video_latents.to(noise), noise, timestep)
- # if pure noise then scale the initial latents by the Scheduler's init sigma
- latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
- latents = latents.to(device)
- else:
- noise = latents.to(device)
- latents = noise * self.scheduler.init_noise_sigma
-
- # scale the initial noise by the standard deviation required by the scheduler
- outputs = (latents,)
-
- if return_noise:
- outputs += (noise,)
-
- if return_video_latents:
- outputs += (video_latents,)
-
- return outputs
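-
- # Worked sketch of the strength handling above (hypothetical values): with strength=1.0
- # (is_strength_max=True) the latents start as pure noise scaled by the scheduler's init
- # sigma; with strength=0.7 and 50 inference steps the encoded video latents are noised via
- # scheduler.add_noise(...) at the timestep returned by get_timesteps, and only the last
- # 35 steps of the schedule are then denoised.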
-
- def prepare_mask_latents(
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
- ):
- # resize the mask to latents shape as we concatenate the mask to the latents
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
- # and half precision
-
- if mask is not None:
- mask = mask.to(device=device, dtype=self.vae.dtype)
- bs = 1
- new_mask = []
- for i in range(0, mask.shape[0], bs):
- mask_bs = mask[i : i + bs]
- mask_bs = self.vae.encode(mask_bs)[0]
- mask_bs = mask_bs.mode()
- new_mask.append(mask_bs)
- mask = torch.cat(new_mask, dim = 0)
- mask = mask * self.vae.config.scaling_factor
-
- if masked_image is not None:
- if self.transformer.config.add_noise_in_inpaint_model:
- masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
- masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
- bs = 1
- new_mask_pixel_values = []
- for i in range(0, masked_image.shape[0], bs):
- mask_pixel_values_bs = masked_image[i : i + bs]
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
- mask_pixel_values_bs = mask_pixel_values_bs.mode()
- new_mask_pixel_values.append(mask_pixel_values_bs)
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
- masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
- else:
- masked_image_latents = None
-
- return mask, masked_image_latents
-
- def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
- latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
- latents = 1 / self.vae.config.scaling_factor * latents
-
- frames = self.vae.decode(latents).sample
- frames = (frames / 2 + 0.5).clamp(0, 1)
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
- frames = frames.cpu().float().numpy()
- return frames
-
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
- def prepare_extra_step_kwargs(self, generator, eta):
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
-
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
- return extra_step_kwargs
-
- # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
- def check_inputs(
- self,
- prompt,
- height,
- width,
- negative_prompt,
- callback_on_step_end_tensor_inputs,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- ):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
- if prompt is not None and prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
- " only forward one of the two."
- )
- elif prompt is None and prompt_embeds is None:
- raise ValueError(
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
- )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if prompt is not None and negative_prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
- )
-
- if negative_prompt is not None and negative_prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
- )
-
- if prompt_embeds is not None and negative_prompt_embeds is not None:
- if prompt_embeds.shape != negative_prompt_embeds.shape:
- raise ValueError(
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
- f" {negative_prompt_embeds.shape}."
- )
-
- def fuse_qkv_projections(self) -> None:
- r"""Enables fused QKV projections."""
- self.fusing_transformer = True
- self.transformer.fuse_qkv_projections()
-
- def unfuse_qkv_projections(self) -> None:
- r"""Disable QKV projection fusion if enabled."""
- if not self.fusing_transformer:
- logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
- else:
- self.transformer.unfuse_qkv_projections()
- self.fusing_transformer = False
-
- def _prepare_rotary_positional_embeddings(
- self,
- height: int,
- width: int,
- num_frames: int,
- device: torch.device,
- start_frame: Optional[int] = None,
- end_frame: Optional[int] = None,
- context_frames: Optional[int] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- use_real=True,
- )
-
- if start_frame is not None or context_frames is not None:
- freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
- freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
- if context_frames is not None:
- freqs_cos = freqs_cos[context_frames]
- freqs_sin = freqs_sin[context_frames]
- else:
- freqs_cos = freqs_cos[start_frame:end_frame]
- freqs_sin = freqs_sin[start_frame:end_frame]
-
- freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
- freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
-
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
- return freqs_cos, freqs_sin
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
- @property
- def interrupt(self):
- return self._interrupt
-
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
- def get_timesteps(self, num_inference_steps, strength, device):
- # get the original timestep using init_timestep
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
-
- t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
-
- return timesteps, num_inference_steps - t_start
-
- @torch.no_grad()
- @replace_example_docstring(EXAMPLE_DOC_STRING)
- def __call__(
- self,
- prompt: Optional[Union[str, List[str]]] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- height: int = 480,
- width: int = 720,
- video: Optional[torch.FloatTensor] = None,
- mask_video: Optional[torch.FloatTensor] = None,
- masked_video_latents: Optional[torch.FloatTensor] = None,
- num_frames: int = 49,
- num_inference_steps: int = 50,
- timesteps: Optional[List[int]] = None,
- guidance_scale: float = 6,
- use_dynamic_cfg: bool = False,
- num_videos_per_prompt: int = 1,
- eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: str = "numpy",
- return_dict: bool = False,
- callback_on_step_end: Optional[
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
- ] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 226,
- strength: float = 1,
- noise_aug_strength: float = 0.0563,
- comfyui_progressbar: bool = False,
- context_schedule: Optional[str] = None,
- context_frames: Optional[int] = None,
- context_stride: Optional[int] = None,
- context_overlap: Optional[int] = None,
- freenoise: Optional[bool] = True,
- tora: Optional[dict] = None,
- ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
- """
- Function invoked when calling the pipeline for generation.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
- instead.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
- less than `1`).
- height (`int`, *optional*, defaults to `480`):
- The height in pixels of the generated video.
- width (`int`, *optional*, defaults to `720`):
- The width in pixels of the generated video.
- num_frames (`int`, defaults to `49`):
- Number of frames to generate. Should be one more than a multiple of self.vae_scale_factor_temporal
- (e.g. 49), because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
- num_seconds is 6 and fps is 8. Since videos can be saved at any fps, the only condition that
- needs to be satisfied is the one above.
- num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to `6`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. A higher guidance scale encourages the model to generate videos that are closely linked to the text
- `prompt`, usually at the expense of lower video quality.
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
- The number of videos to generate per prompt.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
- to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will be generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
- argument.
- output_type (`str`, *optional*, defaults to `"numpy"`):
- The output format of the generated video. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
- max_sequence_length (`int`, defaults to `226`):
- Maximum sequence length in encoded prompt. Must be consistent with
- `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
-
- Examples:
-
- Returns:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
- `tuple`. When returning a tuple, the first element is a list with the generated images.
- """
-
- # if num_frames > 49:
- # raise ValueError(
- # "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
- # )
-
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
- height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
- width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
- num_videos_per_prompt = 1
-
- # 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- height,
- width,
- negative_prompt,
- callback_on_step_end_tensor_inputs,
- prompt_embeds,
- negative_prompt_embeds,
- )
- self._guidance_scale = guidance_scale
- self._interrupt = False
-
- # 2. Default call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self._execution_device
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
- if do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-
- # 4. set timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps, num_inference_steps = self.get_timesteps(
- num_inference_steps=num_inference_steps, strength=strength, device=device
- )
- self._num_timesteps = len(timesteps)
- if comfyui_progressbar:
- from comfy.utils import ProgressBar
- pbar = ProgressBar(num_inference_steps + 2)
- # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
- latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
- # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
- is_strength_max = strength == 1.0
-
- # 5. Prepare latents.
- if video is not None:
- video_length = video.shape[2]
- init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
- init_video = init_video.to(dtype=torch.float32)
- init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
- else:
- init_video = None
-
- num_channels_latents = self.vae.config.latent_channels
- num_channels_transformer = self.transformer.config.in_channels
- return_image_latents = num_channels_transformer == num_channels_latents
-
- self.vae.to(device)
-
- latents_outputs = self.prepare_latents(
- batch_size * num_videos_per_prompt,
- num_channels_latents,
- height,
- width,
- video_length,
- self.vae.dtype,
- device,
- generator,
- latents,
- video=init_video,
- timestep=latent_timestep,
- is_strength_max=is_strength_max,
- return_noise=True,
- return_video_latents=return_image_latents,
- context_size=context_frames,
- context_overlap=context_overlap,
- freenoise=freenoise,
- )
- if return_image_latents:
- latents, noise, image_latents = latents_outputs
- else:
- latents, noise = latents_outputs
- if comfyui_progressbar:
- pbar.update(1)
-
- if mask_video is not None:
- if (mask_video == 255).all():
- mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
- masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
-
- mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
- masked_video_latents_input = (
- torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
- )
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
- else:
- # Prepare mask latent variables
- video_length = video.shape[2]
- mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
- mask_condition = mask_condition.to(dtype=torch.float32)
- mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
-
- if num_channels_transformer != num_channels_latents:
- mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
- if masked_video_latents is None:
- masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
- else:
- masked_video = masked_video_latents
-
- _, masked_video_latents = self.prepare_mask_latents(
- None,
- masked_video,
- batch_size,
- height,
- width,
- self.vae.dtype,
- device,
- generator,
- do_classifier_free_guidance,
- noise_aug_strength=noise_aug_strength,
- )
- mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
- mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
-
- mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
- mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
-
- mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
- masked_video_latents_input = (
- torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
- )
-
- mask = rearrange(mask, "b c f h w -> b f c h w")
- mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
- masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
-
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
- else:
- mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
- mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
- mask = rearrange(mask, "b c f h w -> b f c h w")
-
- inpaint_latents = None
- else:
- if num_channels_transformer != num_channels_latents:
- mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
- masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
-
- mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
- masked_video_latents_input = (
- torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
- )
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
- else:
- mask = torch.zeros_like(init_video[:, :1])
- mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
- mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
- mask = rearrange(mask, "b c f h w -> b f c h w")
-
- inpaint_latents = None
-
- self.vae.to(torch.device("cpu"))
-
- if comfyui_progressbar:
- pbar.update(1)
-
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-        # Set up the sliding-window context schedule if requested
- if context_schedule is not None:
- print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
- use_context_schedule = True
- from .context import get_context_scheduler
- context = get_context_scheduler(context_schedule)
- else:
- use_context_schedule = False
-            print("Context schedule disabled")
- # 7. Create rotary embeds if required
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-        if tora is not None:
-            video_flow_features = tora["video_flow_features"]
-            if do_classifier_free_guidance:
-                video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
-
- if tora is not None:
- trajectory_length = tora["video_flow_features"].shape[1]
- logger.info(f"Tora trajectory length: {trajectory_length}")
- logger.info(f"Tora trajectory shape: {tora['video_flow_features'].shape}")
- logger.info(f"latents shape: {latents.shape}")
- if trajectory_length != latents.shape[1]:
-                raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent frame count {latents.shape[1]}")
- for module in self.transformer.fuser_list:
- for param in module.parameters():
- param.data = param.data.to(device)
-
- # 8. Denoising loop
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
- from ..latent_preview import prepare_callback
- callback = prepare_callback(self.transformer, num_inference_steps)
-
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- # for DPM-solver++
- old_pred_original_sample = None
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- if use_context_schedule:
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latent_model_input.shape[0])
-
- context_queue = list(context(
- i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
- ))
- counter = torch.zeros_like(latent_model_input)
- noise_pred = torch.zeros_like(latent_model_input)
-
- current_step_percentage = i / num_inference_steps
-
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-
- for c in context_queue:
- partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                        partial_inpaint_latents = inpaint_latents[:, c, :, :, :] if inpaint_latents is not None else None
-                        if partial_inpaint_latents is not None:
-                            # pin the global first frame's inpaint conditioning at the start of every window
-                            partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]
- if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]):
- if do_classifier_free_guidance:
- partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
- else:
- partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :]
- else:
- partial_video_flow_features = None
-
- # predict noise model_output
- noise_pred[:, c, :, :, :] += self.transformer(
- hidden_states=partial_latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- inpaint_latents=partial_inpaint_latents,
- video_flow_features=partial_video_flow_features
- )[0]
-
- counter[:, c, :, :, :] += 1
-
- noise_pred = noise_pred.float()
-
- noise_pred /= counter
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- noise_pred,
- old_pred_original_sample,
- t,
- timesteps[i - 1] if i > 0 else None,
- latents,
- **extra_step_kwargs,
- return_dict=False,
- )
- latents = latents.to(prompt_embeds.dtype)
-
- # call the callback, if provided
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- if comfyui_progressbar:
- pbar.update(1)
-
- else:
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latent_model_input.shape[0])
-
- current_step_percentage = i / num_inference_steps
-
- # predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- inpaint_latents=inpaint_latents,
-                        video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
-                    )[0]
- noise_pred = noise_pred.float()
-
- # perform guidance
- if use_dynamic_cfg:
- self._guidance_scale = 1 + guidance_scale * (
- (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
- )
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
- latents, old_pred_original_sample = self.scheduler.step(
- noise_pred,
- old_pred_original_sample,
- t,
- timesteps[i - 1] if i > 0 else None,
- latents,
- **extra_step_kwargs,
- return_dict=False,
- )
- latents = latents.to(prompt_embeds.dtype)
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- if comfyui_progressbar:
- if callback is not None:
- callback(i, latents.detach()[-1], None, num_inference_steps)
- else:
- pbar.update(1)
-
- # Offload all models
- self.maybe_free_model_hooks()
-
- return latents
\ No newline at end of file
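
Note on the sliding-window sampling loop deleted above: when a context schedule is active, the pipeline denoises overlapping frame windows, accumulates the per-window predictions into a full-length buffer, averages them with a per-element counter, and only then applies classifier-free guidance. The following standalone sketch is illustrative only; `denoise_with_windows`, the dummy `predict_fn`, and the tensor shapes are assumptions, not code from this repository.

    import torch

    def denoise_with_windows(latents, windows, predict_fn, guidance_scale=6.0):
        # latents: [2*B, F, C, H, W] with the uncond/cond halves stacked on dim 0;
        # windows: list of frame-index lists that together cover all F frames
        noise_pred = torch.zeros_like(latents)
        counter = torch.zeros_like(latents)
        for w in windows:
            noise_pred[:, w] += predict_fn(latents[:, w])  # partial prediction on one window
            counter[:, w] += 1
        noise_pred /= counter  # average wherever windows overlap
        uncond, cond = noise_pred.chunk(2)
        return uncond + guidance_scale * (cond - uncond)

    # usage with a dummy predictor
    latents = torch.randn(2, 12, 4, 8, 8)
    windows = [list(range(0, 8)), list(range(4, 12))]
    print(denoise_with_windows(latents, windows, lambda x: x * 0.1).shape)  # torch.Size([1, 12, 4, 8, 8])
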
diff --git a/cogvideox_fun/transformer_3d.py b/cogvideox_fun/transformer_3d.py
deleted file mode 100644
index 5b6fef9..0000000
--- a/cogvideox_fun/transformer_3d.py
+++ /dev/null
@@ -1,823 +0,0 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, Optional, Tuple, Union
-
-import os
-import json
-import torch
-import glob
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import is_torch_version, logging
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.attention_processor import AttentionProcessor, FusedCogVideoXAttnProcessor2_0  # CogVideoXAttnProcessor2_0 is re-implemented locally below
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-from einops import rearrange
-try:
- from sageattention import sageattn
- SAGEATTN_IS_AVAILABLE = True
-except Exception:
- SAGEATTN_IS_AVAILABLE = False
-
-def fft(tensor):
-    # Split a [B, C, H, W] tensor's shifted 2D spectrum into low- and high-frequency
-    # components using a centered circular mask of radius min(H, W) // 5.
-    tensor_fft = torch.fft.fft2(tensor)
- tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
- B, C, H, W = tensor.size()
- radius = min(H, W) // 5
-
-    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
- center_x, center_y = W // 2, H // 2
- mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
- low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
- high_freq_mask = ~low_freq_mask
-
- low_freq_fft = tensor_fft_shifted * low_freq_mask
- high_freq_fft = tensor_fft_shifted * high_freq_mask
-
- return low_freq_fft, high_freq_fft
-
-class CogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- attention_mode: Optional[str] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- from diffusers.models.embeddings import apply_rotary_emb
-
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- if attention_mode == "sageattn":
- if SAGEATTN_IS_AVAILABLE:
-                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
- else:
- raise ImportError("sageattn not found")
- else:
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-class CogVideoXPatchEmbed(nn.Module):
- def __init__(
- self,
- patch_size: int = 2,
- in_channels: int = 16,
- embed_dim: int = 1920,
- text_embed_dim: int = 4096,
- bias: bool = True,
- ) -> None:
- super().__init__()
- self.patch_size = patch_size
-
- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
- r"""
- Args:
- text_embeds (`torch.Tensor`):
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
- image_embeds (`torch.Tensor`):
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
- """
- text_embeds = self.text_proj(text_embeds)
-
- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
-
- embeds = torch.cat(
- [text_embeds, image_embeds], dim=1
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
- return embeds
-
-@maybe_allow_in_graph
-class CogVideoXBlock(nn.Module):
- r"""
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
-
- Parameters:
- dim (`int`):
- The number of channels in the input and output.
- num_attention_heads (`int`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`):
- The number of channels in each head.
- time_embed_dim (`int`):
- The number of channels in timestep embedding.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to be used in feed-forward.
- attention_bias (`bool`, defaults to `False`):
- Whether or not to use bias in attention projection layers.
- qk_norm (`bool`, defaults to `True`):
- Whether or not to use normalization after query and key projections in Attention.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_eps (`float`, defaults to `1e-5`):
- Epsilon value for normalization layers.
-        final_dropout (`bool`, defaults to `True`):
- Whether to apply a final dropout after the last feed-forward layer.
- ff_inner_dim (`int`, *optional*, defaults to `None`):
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
- ff_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Feed-forward layer.
- attention_out_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Attention output projection layer.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- time_embed_dim: int,
- dropout: float = 0.0,
- activation_fn: str = "gelu-approximate",
- attention_bias: bool = False,
- qk_norm: bool = True,
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- final_dropout: bool = True,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- attention_mode: Optional[str] = None,
- ):
- super().__init__()
-
- # 1. Self Attention
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.attn1 = Attention(
- query_dim=dim,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- qk_norm="layer_norm" if qk_norm else None,
- eps=1e-6,
- bias=attention_bias,
- out_bias=attention_out_bias,
- processor=CogVideoXAttnProcessor2_0(),
- )
-
- # 2. Feed Forward
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
- self.cached_hidden_states = []
- self.cached_encoder_hidden_states = []
- self.attention_mode = attention_mode
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- temb: torch.Tensor,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- video_flow_feature: Optional[torch.Tensor] = None,
- fuser=None,
- block_use_fastercache=False,
- fastercache_counter=0,
- fastercache_start_step=15,
- fastercache_device="cuda:0",
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
- hidden_states, encoder_hidden_states, temb
- )
- # Tora Motion-guidance Fuser
- if video_flow_feature is not None:
- H, W = video_flow_feature.shape[-2:]
- T = norm_hidden_states.shape[1] // H // W
- h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
- h = fuser(h, video_flow_feature.to(h), T=T)
- norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
- del h, fuser
-
- #region fastercache
- if block_use_fastercache:
- B = norm_hidden_states.shape[0]
-            # reuse attention on most steps by linearly extrapolating the two most recent cached outputs
-            if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_hidden_states[-1].shape[0] >= B:
- attn_hidden_states = (
- self.cached_hidden_states[1][:B] +
- (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
- * 0.3
- ).to(norm_hidden_states.device, non_blocking=True)
- attn_encoder_hidden_states = (
- self.cached_encoder_hidden_states[1][:B] +
- (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
- * 0.3
- ).to(norm_hidden_states.device, non_blocking=True)
- else:
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=norm_encoder_hidden_states,
- image_rotary_emb=image_rotary_emb,
- attention_mode=self.attention_mode,
- )
- if fastercache_counter == fastercache_start_step:
- self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
- self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
- elif fastercache_counter > fastercache_start_step:
- self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
- self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
- else:
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=norm_encoder_hidden_states,
- image_rotary_emb=image_rotary_emb,
- attention_mode=self.attention_mode,
- )
-
- hidden_states = hidden_states + gate_msa * attn_hidden_states
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
- hidden_states, encoder_hidden_states, temb
- )
-
- # feed-forward
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- ff_output = self.ff(norm_hidden_states)
-
- hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
-
- return hidden_states, encoder_hidden_states
-
-
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
- """
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
-
- Parameters:
- num_attention_heads (`int`, defaults to `30`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`, defaults to `64`):
- The number of channels in each head.
- in_channels (`int`, defaults to `16`):
- The number of channels in the input.
- out_channels (`int`, *optional*, defaults to `16`):
- The number of channels in the output.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- time_embed_dim (`int`, defaults to `512`):
- Output dimension of timestep embeddings.
- text_embed_dim (`int`, defaults to `4096`):
- Input dimension of text embeddings from the text encoder.
- num_layers (`int`, defaults to `30`):
- The number of layers of Transformer blocks to use.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- attention_bias (`bool`, defaults to `True`):
- Whether or not to use bias in the attention projection layers.
- sample_width (`int`, defaults to `90`):
- The width of the input latents.
- sample_height (`int`, defaults to `60`):
- The height of the input latents.
- sample_frames (`int`, defaults to `49`):
-            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
-            (the pixel-frame count) instead of 13 (the latent-frame count used in CogVideoX's default and recommended
-            settings) and is kept at 49 for backwards compatibility. To create a transformer with K latent frames,
-            pass ((K - 1) * temporal_compression_ratio + 1); for K = 13 and a compression ratio of 4 this gives 49.
- patch_size (`int`, defaults to `2`):
- The size of the patches to use in the patch embedding layer.
- temporal_compression_ratio (`int`, defaults to `4`):
- The compression ratio across the temporal dimension. See documentation for `sample_frames`.
- max_text_seq_length (`int`, defaults to `226`):
- The maximum sequence length of the input text embeddings.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to use in feed-forward.
- timestep_activation_fn (`str`, defaults to `"silu"`):
- Activation function to use when generating the timestep embeddings.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether or not to use elementwise affine in normalization layers.
- norm_eps (`float`, defaults to `1e-5`):
- The epsilon value to use in normalization layers.
- spatial_interpolation_scale (`float`, defaults to `1.875`):
- Scaling factor to apply in 3D positional embeddings across spatial dimensions.
- temporal_interpolation_scale (`float`, defaults to `1.0`):
- Scaling factor to apply in 3D positional embeddings across temporal dimensions.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 30,
- attention_head_dim: int = 64,
- in_channels: int = 16,
- out_channels: Optional[int] = 16,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- time_embed_dim: int = 512,
- text_embed_dim: int = 4096,
- num_layers: int = 30,
- dropout: float = 0.0,
- attention_bias: bool = True,
- sample_width: int = 90,
- sample_height: int = 60,
- sample_frames: int = 49,
- patch_size: int = 2,
- temporal_compression_ratio: int = 4,
- max_text_seq_length: int = 226,
- activation_fn: str = "gelu-approximate",
- timestep_activation_fn: str = "silu",
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- spatial_interpolation_scale: float = 1.875,
- temporal_interpolation_scale: float = 1.0,
- use_rotary_positional_embeddings: bool = False,
- add_noise_in_inpaint_model: bool = False,
- attention_mode: Optional[str] = None,
- ):
- super().__init__()
- inner_dim = num_attention_heads * attention_head_dim
-
- post_patch_height = sample_height // patch_size
- post_patch_width = sample_width // patch_size
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
- self.post_patch_height = post_patch_height
- self.post_patch_width = post_patch_width
- self.post_time_compression_frames = post_time_compression_frames
- self.patch_size = patch_size
-
- # 1. Patch embedding
- self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
- self.embedding_dropout = nn.Dropout(dropout)
-
- # 2. 3D positional embeddings
- spatial_pos_embedding = get_3d_sincos_pos_embed(
- inner_dim,
- (post_patch_width, post_patch_height),
- post_time_compression_frames,
- spatial_interpolation_scale,
- temporal_interpolation_scale,
- )
- spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
- pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
- pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
- self.register_buffer("pos_embedding", pos_embedding, persistent=False)
-
- # 3. Time embeddings
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
-
- # 4. Define spatio-temporal transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- CogVideoXBlock(
- dim=inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- time_embed_dim=time_embed_dim,
- dropout=dropout,
- activation_fn=activation_fn,
- attention_bias=attention_bias,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- )
- for _ in range(num_layers)
- ]
- )
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
-
- # 5. Output blocks
- self.norm_out = AdaLayerNorm(
- embedding_dim=time_embed_dim,
- output_dim=2 * inner_dim,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- chunk_dim=1,
- )
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
- self.gradient_checkpointing = False
-
- self.fuser_list = None
-
- self.use_fastercache = False
- self.fastercache_counter = 0
- self.fastercache_start_step = 15
- self.fastercache_lf_step = 40
- self.fastercache_hf_step = 30
- self.fastercache_device = "cuda"
- self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
- self.attention_mode = attention_mode
-
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model,
-            indexed by their weight names.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- timestep: Union[int, float, torch.LongTensor],
- timestep_cond: Optional[torch.Tensor] = None,
- inpaint_latents: Optional[torch.Tensor] = None,
- control_latents: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- video_flow_features: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- batch_size, num_frames, channels, height, width = hidden_states.shape
-
- # 1. Time embedding
- timesteps = timestep
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=hidden_states.dtype)
- emb = self.time_embedding(t_emb, timestep_cond)
-
- # 2. Patch embedding
- if inpaint_latents is not None:
- hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
- if control_latents is not None:
- hidden_states = torch.concat([hidden_states, control_latents], 2)
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-
- # 3. Position embedding
- text_seq_length = encoder_hidden_states.shape[1]
- if not self.config.use_rotary_positional_embeddings:
- seq_length = height * width * num_frames // (self.config.patch_size**2)
- # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
- pos_embeds = self.pos_embedding
- emb_size = hidden_states.size()[-1]
- pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
- pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
- pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
- pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
- hidden_states = hidden_states + pos_embeds
- hidden_states = self.embedding_dropout(hidden_states)
-
- encoder_hidden_states = hidden_states[:, :text_seq_length]
- hidden_states = hidden_states[:, text_seq_length:]
-
- if self.use_fastercache:
-            self.fastercache_counter += 1
-            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
- # 4. Transformer blocks
- for i, block in enumerate(self.transformer_blocks):
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states[:1],
- encoder_hidden_states=encoder_hidden_states[:1],
- temb=emb[:1],
- image_rotary_emb=image_rotary_emb,
- video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
- fuser = self.fuser_list[i] if self.fuser_list is not None else None,
- block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
- fastercache_start_step = self.fastercache_start_step,
- fastercache_counter = self.fastercache_counter,
- fastercache_device = self.fastercache_device
- )
-
- if not self.config.use_rotary_positional_embeddings:
- # CogVideoX-2B
- hidden_states = self.norm_final(hidden_states)
- else:
- # CogVideoX-5B
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states = self.norm_final(hidden_states)
- hidden_states = hidden_states[:, text_seq_length:]
-
- # 5. Final block
- hidden_states = self.norm_out(hidden_states, temb=emb[:1])
- hidden_states = self.proj_out(hidden_states)
-
- # 6. Unpatchify
- p = self.config.patch_size
- output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-
- (bb, tt, cc, hh, ww) = output.shape
- cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
- lf_c, hf_c = fft(cond.float())
- #lf_step = 40
- #hf_step = 30
- if self.fastercache_counter <= self.fastercache_lf_step:
- self.delta_lf = self.delta_lf * 1.1
- if self.fastercache_counter >= self.fastercache_hf_step:
- self.delta_hf = self.delta_hf * 1.1
-
- new_hf_uc = self.delta_hf + hf_c
- new_lf_uc = self.delta_lf + lf_c
-
- combine_uc = new_lf_uc + new_hf_uc
- combined_fft = torch.fft.ifftshift(combine_uc)
- recovered_uncond = torch.fft.ifft2(combined_fft).real
- recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
- output = torch.cat([output, recovered_uncond])
- else:
- # 4. Transformer blocks
- for i, block in enumerate(self.transformer_blocks):
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=emb,
- image_rotary_emb=image_rotary_emb,
- video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
- fuser = self.fuser_list[i] if self.fuser_list is not None else None,
- block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
- fastercache_counter = self.fastercache_counter,
- fastercache_start_step = self.fastercache_start_step,
- fastercache_device = self.fastercache_device
- )
-
- if not self.config.use_rotary_positional_embeddings:
- # CogVideoX-2B
- hidden_states = self.norm_final(hidden_states)
- else:
- # CogVideoX-5B
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states = self.norm_final(hidden_states)
- hidden_states = hidden_states[:, text_seq_length:]
-
- # 5. Final block
- hidden_states = self.norm_out(hidden_states, temb=emb)
- hidden_states = self.proj_out(hidden_states)
-
- # 6. Unpatchify
- p = self.config.patch_size
- output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-
- if self.fastercache_counter >= self.fastercache_start_step + 1:
- (bb, tt, cc, hh, ww) = output.shape
- cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
- uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
-
- lf_c, hf_c = fft(cond)
- lf_uc, hf_uc = fft(uncond)
-
- self.delta_lf = lf_uc - lf_c
- self.delta_hf = hf_uc - hf_c
-
-
- if not return_dict:
- return (output,)
- return Transformer2DModelOutput(sample=output)
-
- @classmethod
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
- if subfolder is not None:
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
-
- config_file = os.path.join(pretrained_model_path, 'config.json')
- if not os.path.isfile(config_file):
- raise RuntimeError(f"{config_file} does not exist")
- with open(config_file, "r") as f:
- config = json.load(f)
-
- from diffusers.utils import WEIGHTS_NAME
- model = cls.from_config(config, **transformer_additional_kwargs)
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
- if os.path.exists(model_file):
- state_dict = torch.load(model_file, map_location="cpu")
- elif os.path.exists(model_file_safetensors):
- from safetensors.torch import load_file, safe_open
- state_dict = load_file(model_file_safetensors)
- else:
- from safetensors.torch import load_file, safe_open
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
- state_dict = {}
- for model_file_safetensors in model_files_safetensors:
- _state_dict = load_file(model_file_safetensors)
- for key in _state_dict:
- state_dict[key] = _state_dict[key]
-
- if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
- new_shape = model.state_dict()['patch_embed.proj.weight'].size()
- if len(new_shape) == 5:
- state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
- state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
- else:
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
- else:
- model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
-
- tmp_state_dict = {}
- for key in state_dict:
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
- tmp_state_dict[key] = state_dict[key]
- else:
-                print(key, "size does not match, skipping")
- state_dict = tmp_state_dict
-
- m, u = model.load_state_dict(state_dict, strict=False)
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
- print(m)
-
- params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
- print(f"### Mamba Parameters: {sum(params) / 1e6} M")
-
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
-
- return model
\ No newline at end of file
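
Note on the FasterCache path in the deleted transformer above: cached single-batch outputs are corrected in frequency space, using the fft() helper to separate low and high bands with a centered circular mask before the stored deltas are re-applied. The sketch below is a hedged, standalone approximation of that band split (the split_frequencies name and the explicit dim=(-2, -1) arguments are my additions, not the original helper).

    import torch

    def split_frequencies(x, radius_div=5):
        # x: [B, C, H, W]; returns the (low, high) bands of the centered 2D spectrum
        B, C, H, W = x.shape
        spec = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
        ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
        radius = min(H, W) // radius_div
        mask = (xs - W // 2) ** 2 + (ys - H // 2) ** 2 <= radius ** 2
        return spec * mask, spec * ~mask

    x = torch.randn(1, 4, 32, 32)
    low, high = split_frequencies(x)
    # the two bands partition the spectrum, so the inverse transform recovers x
    recon = torch.fft.ifft2(torch.fft.ifftshift(low + high, dim=(-2, -1))).real
    print(torch.allclose(recon, x, atol=1e-4))  # True
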
diff --git a/cogvideox_fun/utils.py b/cogvideox_fun/utils.py
index f790161..9f670ec 100644
--- a/cogvideox_fun/utils.py
+++ b/cogvideox_fun/utils.py
@@ -1,26 +1,6 @@
-import os
-import gc
import numpy as np
-import torch
from PIL import Image
-# Copyright (c) OpenMMLab. All rights reserved.
-
-def tensor2pil(image):
- return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
-
-def numpy2pil(image):
- return Image.fromarray(np.clip(255. * image, 0, 255).astype(np.uint8))
-
-def to_pil(image):
- if isinstance(image, Image.Image):
- return image
- if isinstance(image, torch.Tensor):
- return tensor2pil(image)
- if isinstance(image, np.ndarray):
- return numpy2pil(image)
- raise ValueError(f"Cannot convert {type(image)} to PIL.Image")
-
ASPECT_RATIO_512 = {
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
@@ -54,126 +34,10 @@ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_5
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
-
def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
target_pixels = int(base_resolution) * int(base_resolution)
original_width, original_height = Image.open(image).size
ratio = (target_pixels / (original_width * original_height)) ** 0.5
width_slider = round(original_width * ratio)
height_slider = round(original_height * ratio)
- return height_slider, width_slider
-
-def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
- if validation_image_start is not None and validation_image_end is not None:
- if type(validation_image_start) is str and os.path.isfile(validation_image_start):
- image_start = clip_image = Image.open(validation_image_start).convert("RGB")
- image_start = image_start.resize([sample_size[1], sample_size[0]])
- clip_image = clip_image.resize([sample_size[1], sample_size[0]])
- else:
- image_start = clip_image = validation_image_start
- image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
- clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
-
- if type(validation_image_end) is str and os.path.isfile(validation_image_end):
- image_end = Image.open(validation_image_end).convert("RGB")
- image_end = image_end.resize([sample_size[1], sample_size[0]])
- else:
- image_end = validation_image_end
- image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
-
- if type(image_start) is list:
- clip_image = clip_image[0]
- start_video = torch.cat(
- [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
- dim=2
- )
- input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
- input_video[:, :, :len(image_start)] = start_video
-
- input_video_mask = torch.zeros_like(input_video[:, :1])
- input_video_mask[:, :, len(image_start):] = 255
- else:
- input_video = torch.tile(
- torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
- [1, 1, video_length, 1, 1]
- )
- input_video_mask = torch.zeros_like(input_video[:, :1])
- input_video_mask[:, :, 1:] = 255
-
- if type(image_end) is list:
- image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
- end_video = torch.cat(
- [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
- dim=2
- )
- input_video[:, :, -len(end_video):] = end_video
-
- input_video_mask[:, :, -len(image_end):] = 0
- else:
- image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
- input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
- input_video_mask[:, :, -1:] = 0
-
- input_video = input_video / 255
-
- elif validation_image_start is not None:
- if type(validation_image_start) is str and os.path.isfile(validation_image_start):
- image_start = clip_image = Image.open(validation_image_start).convert("RGB")
- image_start = image_start.resize([sample_size[1], sample_size[0]])
- clip_image = clip_image.resize([sample_size[1], sample_size[0]])
- else:
- image_start = clip_image = validation_image_start
- image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
- clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
- image_end = None
-
- if type(image_start) is list:
- clip_image = clip_image[0]
- start_video = torch.cat(
- [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
- dim=2
- )
- input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
- input_video[:, :, :len(image_start)] = start_video
- input_video = input_video / 255
-
- input_video_mask = torch.zeros_like(input_video[:, :1])
- input_video_mask[:, :, len(image_start):] = 255
- else:
- input_video = torch.tile(
- torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
- [1, 1, video_length, 1, 1]
- ) / 255
- input_video_mask = torch.zeros_like(input_video[:, :1])
- input_video_mask[:, :, 1:, ] = 255
- else:
- image_start = None
- image_end = None
- input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
- input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
- clip_image = None
-
- del image_start
- del image_end
- gc.collect()
-
- return input_video, input_video_mask, clip_image
-
-def get_video_to_video_latent(input_video_path, video_length, sample_size, validation_video_mask=None):
- input_video = input_video_path
-
- input_video = torch.from_numpy(np.array(input_video))[:video_length]
- input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
-
- if validation_video_mask is not None:
- validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
- input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
-
- input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
- input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
- input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
- else:
- input_video_mask = torch.zeros_like(input_video[:, :1])
- input_video_mask[:, :, :] = 255
-
- return input_video, input_video_mask, None
\ No newline at end of file
+ return height_slider, width_slider
\ No newline at end of file
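
For reference, the helpers that survive in cogvideox_fun/utils.py map an input height/width to the closest predefined aspect-ratio bucket. A minimal usage sketch follows; the bucket table here is a small made-up subset, not the full ASPECT_RATIO_512 dict.

    # illustrative subset of buckets: {aspect ratio as string: [height, width]}
    buckets = {'0.5': [352.0, 704.0], '1.0': [512.0, 512.0], '2.0': [704.0, 352.0]}

    def closest_bucket(height, width, ratios=buckets):
        aspect = height / width
        key = min(ratios, key=lambda r: abs(float(r) - aspect))
        return ratios[key], float(key)

    print(closest_bucket(720, 1280))  # ([352.0, 704.0], 0.5) -- 720/1280 = 0.5625 is nearest to 0.5
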
diff --git a/cogvideox_fun/context.py b/context.py
similarity index 100%
rename from cogvideox_fun/context.py
rename to context.py
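
The renamed context.py supplies get_context_scheduler, which yields the overlapping frame-index windows consumed by the sampling loop shown earlier. The function below is only a hedged illustration of what such a static schedule can look like; the name and the exact windowing policy are assumptions, not the module's implementation.

    def static_windows(num_frames, context_frames=12, context_overlap=4):
        # step forward by the non-overlapping part of each window
        assert context_overlap < context_frames
        step = context_frames - context_overlap
        windows, start = [], 0
        while start < num_frames:
            end = min(start + context_frames, num_frames)
            if end - start < context_frames:
                # final window: shift it back so it still spans context_frames frames
                windows.append(list(range(max(end - context_frames, 0), end)))
            else:
                windows.append(list(range(start, end)))
            if end == num_frames:
                break
            start += step
        return windows

    print(static_windows(13))  # two windows: frames 0-11 and frames 1-12
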
diff --git a/convert_weight_sat2hf.py b/convert_weight_sat2hf.py
deleted file mode 100644
index 545925b..0000000
--- a/convert_weight_sat2hf.py
+++ /dev/null
@@ -1,303 +0,0 @@
-"""
-
-The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
-This script supports the conversion of the following models:
-- CogVideoX-2B
-- CogVideoX-5B, CogVideoX-5B-I2V
-- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
-
-Original Script:
-https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
-
-"""
-import argparse
-from typing import Any, Dict
-
-import torch
-from transformers import T5EncoderModel, T5Tokenizer
-
-from diffusers import (
- AutoencoderKLCogVideoX,
- CogVideoXDDIMScheduler,
- CogVideoXImageToVideoPipeline,
- CogVideoXPipeline,
- #CogVideoXTransformer3DModel,
-)
-from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
-
-
-def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
- to_q_key = key.replace("query_key_value", "to_q")
- to_k_key = key.replace("query_key_value", "to_k")
- to_v_key = key.replace("query_key_value", "to_v")
- to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
- state_dict[to_q_key] = to_q
- state_dict[to_k_key] = to_k
- state_dict[to_v_key] = to_v
- state_dict.pop(key)
-
-
-def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
- layer_id, weight_or_bias = key.split(".")[-2:]
-
- if "query" in key:
- new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
- elif "key" in key:
- new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
-
- state_dict[new_key] = state_dict.pop(key)
-
-
-def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
- layer_id, _, weight_or_bias = key.split(".")[-3:]
-
- weights_or_biases = state_dict[key].chunk(12, dim=0)
- norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
- norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
-
- norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
- state_dict[norm1_key] = norm1_weights_or_biases
-
- norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
- state_dict[norm2_key] = norm2_weights_or_biases
-
- state_dict.pop(key)
-
-
-def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
- state_dict.pop(key)
-
-
-def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
- key_split = key.split(".")
- layer_index = int(key_split[2])
- replace_layer_index = 4 - 1 - layer_index
-
- key_split[1] = "up_blocks"
- key_split[2] = str(replace_layer_index)
- new_key = ".".join(key_split)
-
- state_dict[new_key] = state_dict.pop(key)
-
-
-TRANSFORMER_KEYS_RENAME_DICT = {
- "transformer.final_layernorm": "norm_final",
- "transformer": "transformer_blocks",
- "attention": "attn1",
- "mlp": "ff.net",
- "dense_h_to_4h": "0.proj",
- "dense_4h_to_h": "2",
- ".layers": "",
- "dense": "to_out.0",
- "input_layernorm": "norm1.norm",
- "post_attn1_layernorm": "norm2.norm",
- "time_embed.0": "time_embedding.linear_1",
- "time_embed.2": "time_embedding.linear_2",
- "mixins.patch_embed": "patch_embed",
- "mixins.final_layer.norm_final": "norm_out.norm",
- "mixins.final_layer.linear": "proj_out",
- "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
- "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
-}
-
-TRANSFORMER_SPECIAL_KEYS_REMAP = {
- "query_key_value": reassign_query_key_value_inplace,
- "query_layernorm_list": reassign_query_key_layernorm_inplace,
- "key_layernorm_list": reassign_query_key_layernorm_inplace,
- "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
- "embed_tokens": remove_keys_inplace,
- "freqs_sin": remove_keys_inplace,
- "freqs_cos": remove_keys_inplace,
- "position_embedding": remove_keys_inplace,
-}
-
-VAE_KEYS_RENAME_DICT = {
- "block.": "resnets.",
- "down.": "down_blocks.",
- "downsample": "downsamplers.0",
- "upsample": "upsamplers.0",
- "nin_shortcut": "conv_shortcut",
- "encoder.mid.block_1": "encoder.mid_block.resnets.0",
- "encoder.mid.block_2": "encoder.mid_block.resnets.1",
- "decoder.mid.block_1": "decoder.mid_block.resnets.0",
- "decoder.mid.block_2": "decoder.mid_block.resnets.1",
-}
-
-VAE_SPECIAL_KEYS_REMAP = {
- "loss": remove_keys_inplace,
- "up.": replace_up_keys_inplace,
-}
-
-TOKENIZER_MAX_LENGTH = 226
-
-
-def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
- state_dict = saved_dict
- if "model" in saved_dict.keys():
- state_dict = state_dict["model"]
- if "module" in saved_dict.keys():
- state_dict = state_dict["module"]
- if "state_dict" in saved_dict.keys():
- state_dict = state_dict["state_dict"]
- return state_dict
-
-
-def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
- state_dict[new_key] = state_dict.pop(old_key)
-
-
-def convert_transformer(
- ckpt_path: str,
- num_layers: int,
- num_attention_heads: int,
- use_rotary_positional_embeddings: bool,
- i2v: bool,
- dtype: torch.dtype,
-):
- PREFIX_KEY = "model.diffusion_model."
-
- original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- transformer = CogVideoXTransformer3DModel(
- in_channels=32 if i2v else 16,
- num_layers=num_layers,
- num_attention_heads=num_attention_heads,
- use_rotary_positional_embeddings=use_rotary_positional_embeddings,
- use_learned_positional_embeddings=i2v,
- ).to(dtype=dtype)
-
- for key in list(original_state_dict.keys()):
- new_key = key[len(PREFIX_KEY):]
- for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- update_state_dict_inplace(original_state_dict, key, new_key)
-
- for key in list(original_state_dict.keys()):
- for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, original_state_dict)
- transformer.load_state_dict(original_state_dict, strict=True)
- return transformer
-
-
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
- original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
-
- for key in list(original_state_dict.keys()):
- new_key = key[:]
- for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- update_state_dict_inplace(original_state_dict, key, new_key)
-
- for key in list(original_state_dict.keys()):
- for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, original_state_dict)
-
- vae.load_state_dict(original_state_dict, strict=True)
- return vae
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
- )
- parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
- parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
- parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
- parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
- parser.add_argument(
- "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
- )
- parser.add_argument(
- "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
- )
- # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
- parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
- # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
- parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
- # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
- parser.add_argument(
- "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
- )
- # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
- parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
- # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
- parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
- parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- args = get_args()
-
- transformer = None
- vae = None
-
- if args.fp16 and args.bf16:
- raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
-
- dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
-
- if args.transformer_ckpt_path is not None:
- transformer = convert_transformer(
- args.transformer_ckpt_path,
- args.num_layers,
- args.num_attention_heads,
- args.use_rotary_positional_embeddings,
- args.i2v,
- dtype,
- )
- if args.vae_ckpt_path is not None:
- vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
-
- #text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
- #tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
- #text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
-
- # Apparently, the conversion does not work anymore without this :shrug:
- #for param in text_encoder.parameters():
- # param.data = param.data.contiguous()
-
- scheduler = CogVideoXDDIMScheduler.from_config(
- {
- "snr_shift_scale": args.snr_shift_scale,
- "beta_end": 0.012,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "clip_sample": False,
- "num_train_timesteps": 1000,
- "prediction_type": "v_prediction",
- "rescale_betas_zero_snr": True,
- "set_alpha_to_one": True,
- "timestep_spacing": "trailing",
- }
- )
- if args.i2v:
- pipeline_cls = CogVideoXImageToVideoPipeline
- else:
- pipeline_cls = CogVideoXPipeline
-
- pipe = pipeline_cls(
- tokenizer=None,
- text_encoder=None,
- vae=vae,
- transformer=transformer,
- scheduler=scheduler,
- )
-
- if args.fp16:
- pipe = pipe.to(dtype=torch.float16)
- if args.bf16:
- pipe = pipe.to(dtype=torch.bfloat16)
-
- # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
- # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
- # is either fp16/bf16 here).
-
- # This is necessary for users with insufficient memory,
- # such as those using Colab and notebooks, as it can save some memory used for model loading.
- pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 50b0f25..89c72aa 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -76,7 +76,6 @@ class CogVideoXAttnProcessor2_0:
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- #@torch.compiler.disable()
def __call__(
self,
attn: Attention,
diff --git a/model_loading.py b/model_loading.py
index e627351..c77e3c5 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -43,11 +43,8 @@ from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline
from contextlib import nullcontext
-from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
-from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun
-
-from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
-from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file
@@ -121,8 +118,7 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
- "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
- "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+ "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
@@ -132,13 +128,13 @@ class DownloadAndLoadCogVideoModel:
}
}
- RETURN_TYPES = ("COGVIDEOPIPE",)
- RETURN_NAMES = ("cogvideo_pipe", )
+ RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
+ RETURN_NAMES = ("model", "vae", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
- def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
+ def loadmodel(self, model, precision, quantization="disabled", compile="disabled",
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
@@ -215,12 +211,7 @@ class DownloadAndLoadCogVideoModel:
local_dir_use_symlinks=False,
)
- #transformer
- if "Fun" in model:
- transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
- else:
- transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
-
+ transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:
@@ -235,17 +226,17 @@ class DownloadAndLoadCogVideoModel:
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
# VAE
- if "Fun" in model:
- vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
- if "Pose" in model:
- pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
- else:
- pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
- else:
- vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
- pipe = CogVideoXPipeline(vae, transformer, scheduler)
- if "cogvideox-2b-img2vid" in model:
- pipe.input_with_padding = False
+ vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
+
+ #pipeline
+ pipe = CogVideoXPipeline(
+ transformer,
+ scheduler,
+ dtype=dtype,
+ is_fun_inpaint=True if "fun" in model.lower() and "pose" not in model.lower() else False
+ )
+ if "cogvideox-2b-img2vid" in model:
+ pipe.input_with_padding = False
#LoRAs
if lora is not None:
@@ -281,8 +272,19 @@ class DownloadAndLoadCogVideoModel:
lora_scale = lora_scale / lora_rank
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
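+ # "fused" attention modes merge each Attention module's separate Q/K/V projections into a single linear layer before the chosen attention backend is set on the transformer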
+ if "fused" in attention_mode:
+ from diffusers.models.attention import Attention
+ transformer.fuse_qkv_projections = True
+ for module in transformer.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+ transformer.attention_mode = attention_mode
+
+ if compile_args is not None:
+ pipe.transformer.to(memory_format=torch.channels_last)
+
#fp8
- if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
+ if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
if "1.5" in model:
params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
@@ -290,13 +292,20 @@ class DownloadAndLoadCogVideoModel:
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
- if fp8_transformer == "fastmode":
+ if quantization == "fp8_e4m3fn_fastmode":
from .fp8_optimization import convert_fp8_linear
if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
+
+ # compilation
+ if compile_args is not None:
+ torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+ for i, block in enumerate(pipe.transformer.transformer_blocks):
+ if "CogVideoXBlock" in str(block):
+ pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
- elif "torchao" in fp8_transformer:
+ if "torchao" in quantization:
try:
from torchao.quantization import (
quantize_,
@@ -313,14 +322,14 @@ class DownloadAndLoadCogVideoModel:
return isinstance(module, nn.Linear)
return False
- if "fp6" in fp8_transformer: #slower for some reason on 4090
+ if "fp6" in quantization: #slower for some reason on 4090
quant_func = fpx_weight_only(3, 2)
- elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled
+ elif "fp8dq" in quantization: #very fast on 4090 when compiled
quant_func = float8_dynamic_activation_float8_weight()
- elif 'fp8dqrow' in fp8_transformer:
+ elif 'fp8dqrow' in quantization:
from torchao.quantization.quant_api import PerRow
quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
- elif 'int8dq' in fp8_transformer:
+ elif 'int8dq' in quantization:
quant_func = int8_dynamic_activation_int8_weight()
for i, block in enumerate(pipe.transformer.transformer_blocks):
@@ -365,41 +374,19 @@ class DownloadAndLoadCogVideoModel:
# (3): Dropout(p=0.0, inplace=False)
# )
# )
- # )
+ # )
+
+ # if compile == "onediff":
+ # from onediffx import compile_pipe
+ # os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
- # compilation
- if compile == "torch":
- #pipe.transformer.to(memory_format=torch.channels_last)
- if compile_args is not None:
- torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
- for i, block in enumerate(pipe.transformer.transformer_blocks):
- if "CogVideoXBlock" in str(block):
- pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
- else:
- for i, block in enumerate(pipe.transformer.transformer_blocks):
- if "CogVideoXBlock" in str(block):
- pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
-
- transformer.attention_mode = attention_mode
-
- if "fused" in attention_mode:
- from diffusers.models.attention import Attention
- transformer.fuse_qkv_projections = True
- for module in transformer.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- elif compile == "onediff":
- from onediffx import compile_pipe
- os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
-
- pipe = compile_pipe(
- pipe,
- backend="nexfort",
- options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
- ignores=["vae"],
- fuse_qkv_projections= False,
- )
+ # pipe = compile_pipe(
+ # pipe,
+ # backend="nexfort",
+ # options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
+ # ignores=["vae"],
+ # fuse_qkv_projections= False,
+ # )
pipeline = {
"pipe": pipe,
@@ -412,7 +399,7 @@ class DownloadAndLoadCogVideoModel:
"model_name": model,
}
- return (pipeline,)
+ return (pipeline, vae)
#region GGUF
class DownloadAndLoadCogVideoGGUFModel:
@classmethod
@@ -444,8 +431,8 @@ class DownloadAndLoadCogVideoGGUFModel:
}
}
- RETURN_TYPES = ("COGVIDEOPIPE",)
- RETURN_NAMES = ("cogvideo_pipe", )
+ RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
+ RETURN_NAMES = ("model", "vae",)
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
@@ -486,7 +473,6 @@ class DownloadAndLoadCogVideoGGUFModel:
with open(transformer_path) as f:
transformer_config = json.load(f)
-
from . import mz_gguf_loader
import importlib
@@ -498,7 +484,6 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
- transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
elif "I2V" in model or "Interpolation" in model:
transformer_config["in_channels"] = 32
if "1_5" in model:
@@ -508,10 +493,10 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 300
transformer_config["sample_width"] = 300
- transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
else:
transformer_config["in_channels"] = 16
- transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+
+ transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
@@ -564,60 +549,25 @@ class DownloadAndLoadCogVideoGGUFModel:
with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
vae_config = json.load(f)
+ #VAE
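+ # The VAE is built from the bundled config and returned as a separate output rather than being wrapped inside the pipeline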
vae_sd = load_torch_file(vae_path)
- if "fun" in model:
- vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
- vae.load_state_dict(vae_sd)
- if "Pose" in model:
- pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
- else:
- pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
- else:
- vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
- vae.load_state_dict(vae_sd)
- pipe = CogVideoXPipeline(vae, transformer, scheduler)
+ vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
+ vae.load_state_dict(vae_sd)
+ del vae_sd
+ pipe = CogVideoXPipeline(transformer, scheduler, dtype=vae_dtype)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
sd = load_torch_file(gguf_path)
-
- # #LoRAs
- # if lora is not None:
- # if "fun" in model.lower():
- # raise NotImplementedError("LoRA with GGUF is not supported for Fun models")
- # from .lora_utils import merge_lora#, load_lora_into_transformer
- # #for l in lora:
- # # log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
- # # pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"])
- # else:
- # adapter_list = []
- # adapter_weights = []
- # for l in lora:
- # lora_sd = load_torch_file(l["path"])
- # for key, val in lora_sd.items():
- # if "lora_B" in key:
- # lora_rank = val.shape[1]
- # break
- # log.info(f"Loading rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
- # adapter_name = l['path'].split("/")[-1].split(".")[0]
- # adapter_weight = l['strength']
- # pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
-
- # #transformer = load_lora_into_transformer(lora, transformer)
- # adapter_list.append(adapter_name)
- # adapter_weights.append(adapter_weight)
- # for l in lora:
- # pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
- # #pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
-
pipe.transformer = mz_gguf_loader.quantize_load_state_dict(pipe.transformer, sd, device="cpu")
+ del sd
+
if load_device == "offload_device":
pipe.transformer.to(offload_device)
else:
pipe.transformer.to(device)
-
pipeline = {
"pipe": pipe,
"dtype": vae_dtype,
@@ -629,9 +579,253 @@ class DownloadAndLoadCogVideoGGUFModel:
"manual_offloading": True,
}
+ return (pipeline, vae)
+
+#region ModelLoader
+class CogVideoXModelLoader:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}),
+
+ "base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
+ "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "optional quantization method"}),
+ "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
+ "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
+ },
+ "optional": {
+ "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
+ "lora": ("COGLORA", {"default": None}),
+ "compile_args":("COMPILEARGS", ),
+ "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
+ }
+ }
+
+ RETURN_TYPES = ("COGVIDEOMODEL",)
+ RETURN_NAMES = ("model", )
+ FUNCTION = "loadmodel"
+ CATEGORY = "CogVideoWrapper"
+
+ def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
+ block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
+
+ device = mm.get_torch_device()
+ offload_device = mm.unet_offload_device()
+ manual_offloading = True
+ transformer_load_device = device if load_device == "main_device" else offload_device
+ mm.soft_empty_cache()
+
+ base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision]
+
+ model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
+ sd = load_torch_file(model_path, device=transformer_load_device)
+
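+ # Infer the model variant from the patch_embed.proj.weight shape: Fun checkpoints use 33 input channels, I2V 32, T2V 16, and 1.5 checkpoints use a flattened 2D projection (Fun-Pose is told apart from I2V by the extra pos_embedding key)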
+ model_type = ""
+ if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2):
+ model_type = "fun_5b"
+ elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2):
+ model_type = "5b"
+ elif sd["patch_embed.proj.weight"].shape == (3072, 128):
+ model_type = "5b_1_5"
+ elif sd["patch_embed.proj.weight"].shape == (3072, 256):
+ model_type = "5b_I2V_1_5"
+ elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
+ model_type = "fun_2b"
+ elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
+ model_type = "2b"
+ elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
+ if "pos_embedding" in sd:
+ model_type = "fun_5b_pose"
+ else:
+ model_type = "I2V_5b"
+ else:
+ raise Exception("Selected model is not recognized")
+ log.info(f"Detected CogVideoX model type: {model_type}")
+
+ if "5b" in model_type:
+ scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
+ transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_5b.json')
+ elif "2b" in model_type:
+ scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
+ transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_2b.json')
+
+ with open(transformer_config_path) as f:
+ transformer_config = json.load(f)
+
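+ # The transformer is instantiated on the meta device (no memory allocated up front); weights are then assigned tensor-by-tensor below with accelerate's set_module_tensor_to_device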
+ with init_empty_weights():
+ if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
+ transformer_config["in_channels"] = 32
+ if "1_5" in model_type:
+ transformer_config["ofs_embed_dim"] = 512
+ transformer_config["use_learned_positional_embeddings"] = False
+ transformer_config["patch_size_t"] = 2
+ transformer_config["patch_bias"] = False
+ transformer_config["sample_height"] = 300
+ transformer_config["sample_width"] = 300
+ elif "fun" in model_type:
+ transformer_config["in_channels"] = 33
+ else:
+ if "1_5" in model_type:
+ transformer_config["use_learned_positional_embeddings"] = False
+ transformer_config["patch_size_t"] = 2
+ transformer_config["patch_bias"] = False
+ #transformer_config["sample_height"] = 300 todo: check if this is needed
+ #transformer_config["sample_width"] = 300
+ transformer_config["in_channels"] = 16
+
+ transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+
+ #load weights
+ #params_to_keep = {}
+ log.info("Using accelerate to load and assign model weights to device...")
+
+ for name, param in transformer.named_parameters():
+ #dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+ set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
+ del sd
+
+
+ #scheduler
+ with open(scheduler_config_path) as f:
+ scheduler_config = json.load(f)
+ scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config, subfolder="scheduler")
+
+ if block_edit is not None:
+ transformer = remove_specific_blocks(transformer, block_edit)
+
+ if "fused" in attention_mode:
+ from diffusers.models.attention import Attention
+ transformer.fuse_qkv_projections = True
+ for module in transformer.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+ transformer.attention_mode = attention_mode
+
+ if "fun" in model_type:
+ if not "pose" in model_type:
+ raise NotImplementedError("Fun models besides pose are not supported with this loader yet")
+ #pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) # unreachable leftover: the Fun inpaint pipeline was removed from this wrapper
+ else:
+ pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype)
+ else:
+ pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype)
+
+ if enable_sequential_cpu_offload:
+ pipe.enable_sequential_cpu_offload()
+
+ #LoRAs
+ if lora is not None:
+ from .lora_utils import merge_lora#, load_lora_into_transformer
+ if "fun" in model.lower():
+ for l in lora:
+ log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
+ transformer = merge_lora(transformer, l["path"], l["strength"])
+ else:
+ adapter_list = []
+ adapter_weights = []
+ for l in lora:
+ fuse = True if l["fuse_lora"] else False
+ lora_sd = load_torch_file(l["path"])
+ for key, val in lora_sd.items():
+ if "lora_B" in key:
+ lora_rank = val.shape[1]
+ break
+ log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
+ adapter_name = l['path'].split("/")[-1].split(".")[0]
+ adapter_weight = l['strength']
+ pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
+
+ #transformer = load_lora_into_transformer(lora, transformer)
+ adapter_list.append(adapter_name)
+ adapter_weights.append(adapter_weight)
+ for l in lora:
+ pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
+ if fuse:
+ lora_scale = 1
+ dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
+ if any(item in lora[-1]["path"].lower() for item in dimension_loras):
+ lora_scale = lora_scale / lora_rank
+ pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
+
+ if compile_args is not None:
+ pipe.transformer.to(memory_format=torch.channels_last)
+
+ #quantization
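+ # Cast most transformer weights to float8_e4m3fn, keeping precision-sensitive tensors (embeddings, norms, attention biases) in the base dtype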
+ if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
+ params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
+ if "1.5" in model:
+ params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
+ for name, param in pipe.transformer.named_parameters():
+ if not any(keyword in name for keyword in params_to_keep):
+ param.data = param.data.to(torch.float8_e4m3fn)
+
+ if quantization == "fp8_e4m3fn_fast":
+ from .fp8_optimization import convert_fp8_linear
+ if "1.5" in model:
+ params_to_keep.update({"ff"}) #otherwise NaNs
+ convert_fp8_linear(pipe.transformer, base_dtype, params_to_keep=params_to_keep)
+
+ #compile
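+ # Compile each CogVideoXBlock separately with the user-provided torch.compile settings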
+ if compile_args is not None:
+ torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+ for i, block in enumerate(pipe.transformer.transformer_blocks):
+ if "CogVideoXBlock" in str(block):
+ pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
+
+ if "torchao" in quantization:
+ try:
+ from torchao.quantization import (
+ quantize_,
+ fpx_weight_only,
+ float8_dynamic_activation_float8_weight,
+ int8_dynamic_activation_int8_weight
+ )
+ except:
+ raise ImportError("torchao is not installed, please install torchao to use fp8dq")
+
+ def filter_fn(module: nn.Module, fqn: str) -> bool:
+ target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
+ if any(sub in fqn for sub in target_submodules):
+ return isinstance(module, nn.Linear)
+ return False
+
+ if "fp6" in quantization: #slower for some reason on 4090
+ quant_func = fpx_weight_only(3, 2)
+ elif "fp8dq" in quantization: #very fast on 4090 when compiled
+ quant_func = float8_dynamic_activation_float8_weight()
+ elif 'fp8dqrow' in quantization:
+ from torchao.quantization.quant_api import PerRow
+ quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
+ elif 'int8dq' in quantization:
+ quant_func = int8_dynamic_activation_int8_weight()
+
+ for i, block in enumerate(pipe.transformer.transformer_blocks):
+ if "CogVideoXBlock" in str(block):
+ quantize_(block, quant_func, filter_fn=filter_fn)
+
+ manual_offloading = False # to disable manual .to(device) calls
+ log.info(f"Quantized transformer blocks to {quantization}")
+
+ # if load_device == "offload_device":
+ # pipe.transformer.to(offload_device)
+ # else:
+ # pipe.transformer.to(device)
+
+ pipeline = {
+ "pipe": pipe,
+ "dtype": base_dtype,
+ "base_path": model,
+ "onediff": False,
+ "cpu_offloading": enable_sequential_cpu_offload,
+ "scheduler_config": scheduler_config,
+ "model_name": model,
+ "manual_offloading": manual_offloading,
+ }
+
return (pipeline,)
-#revion VAE
+#region VAE
class CogVideoXVAELoader:
@classmethod
@@ -829,6 +1023,7 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoXVAELoader": CogVideoXVAELoader,
+ "CogVideoXModelLoader": CogVideoXModelLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -837,4 +1032,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadToraModel": "(Down)load Tora Model",
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
+ "CogVideoXModelLoader": "CogVideoX Model Loader",
}
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index b18a978..bd89609 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,7 +1,6 @@
import os
import torch
-import folder_paths
-import comfy.model_management as mm
+import json
from einops import rearrange
from contextlib import nullcontext
@@ -38,11 +37,10 @@ scheduler_mapping = {
}
available_schedulers = list(scheduler_mapping.keys())
-from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
+from diffusers.video_processor import VideoProcessor
-from PIL import Image
-import numpy as np
-import json
+import folder_paths
+import comfy.model_management as mm
script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -129,94 +127,6 @@ class CogVideoXTorchCompileSettings:
return (compile_args, )
#region TextEncode
-class CogVideoEncodePrompt:
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
- "prompt": ("STRING", {"default": "", "multiline": True} ),
- "negative_prompt": ("STRING", {"default": "", "multiline": True} ),
- }
- }
-
- RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
- RETURN_NAMES = ("positive", "negative")
- FUNCTION = "process"
- CATEGORY = "CogVideoWrapper"
-
- def process(self, pipeline, prompt, negative_prompt):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
- pipe = pipeline["pipe"]
- dtype = pipeline["dtype"]
-
- pipe.text_encoder.to(device)
- pipe.transformer.to(offload_device)
-
- positive, negative = pipe.encode_prompt(
- prompt=prompt,
- negative_prompt=negative_prompt,
- do_classifier_free_guidance=True,
- num_videos_per_prompt=1,
- max_sequence_length=226,
- device=device,
- dtype=dtype,
- )
- pipe.text_encoder.to(offload_device)
-
- return (positive, negative)
-
-# Inject clip_l and t5xxl w/ individual strength adjustments for ComfyUI's DualCLIPLoader node for CogVideoX. Use CLIPSave node from any SDXL model then load in a custom clip_l model.
-# For some reason seems to give a lot more movement and consistency on new CogVideoXFun img2vid? set 'type' to flux / DualClipLoader.
-class CogVideoDualTextEncode_311:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "clip": ("CLIP",),
- "clip_l": ("STRING", {"default": "", "multiline": True}),
- "t5xxl": ("STRING", {"default": "", "multiline": True}),
- "clip_l_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), # excessive max for testing, have found intesting results up to 20 max?
- "t5xxl_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), # setting this to 0.0001 or level as high as 18 seems to work.
- }
- }
-
- RETURN_TYPES = ("CONDITIONING",)
- RETURN_NAMES = ("conditioning",)
- FUNCTION = "process"
- CATEGORY = "CogVideoWrapper"
-
- def process(self, clip, clip_l, t5xxl, clip_l_strength, t5xxl_strength):
- load_device = mm.text_encoder_device()
- offload_device = mm.text_encoder_offload_device()
-
- # setup tokenizer for clip_l and t5xxl
- clip.tokenizer.t5xxl.pad_to_max_length = True
- clip.tokenizer.t5xxl.max_length = 226
- clip.cond_stage_model.to(load_device)
-
- # tokenize clip_l and t5xxl
- tokens_l = clip.tokenize(clip_l, return_word_ids=True)
- tokens_t5 = clip.tokenize(t5xxl, return_word_ids=True)
-
- # encode the tokens separately
- embeds_l = clip.encode_from_tokens(tokens_l, return_pooled=False, return_dict=False)
- embeds_t5 = clip.encode_from_tokens(tokens_t5, return_pooled=False, return_dict=False)
-
- # apply strength adjustments to each embedding
- if embeds_l.dim() == 3:
- embeds_l *= clip_l_strength
- if embeds_t5.dim() == 3:
- embeds_t5 *= t5xxl_strength
-
- # combine the embeddings by summing them
- combined_embeds = embeds_l + embeds_t5
-
- # offload the model to save memory
- clip.cond_stage_model.to(offload_device)
-
- return (combined_embeds,)
-
class CogVideoTextEncode:
@classmethod
def INPUT_TYPES(s):
@@ -285,20 +195,31 @@ class CogVideoTextEncodeCombine:
return (embeds, )
-#region ImageEncode
+#region ImageEncode
+
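+# Reference-image noise augmentation: each sample gets Gaussian noise scaled by a log-normally drawn sigma (or a fixed ratio); expects a 5D [B, C, T, H, W] tensor and leaves padded pixels with value -1 untouched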
+def add_noise_to_reference_video(image, ratio=None):
+ if ratio is None:
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+ sigma = torch.exp(sigma).to(image.dtype)
+ else:
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+ image = image + image_noise
+ return image
+
class CogVideoImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
- "image": ("IMAGE", ),
+ "vae": ("VAE",),
+ "start_image": ("IMAGE", ),
},
"optional": {
- "chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}),
+ "end_image": ("IMAGE", ),
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
- "mask": ("MASK", ),
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
- "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
},
}
@@ -307,49 +228,111 @@ class CogVideoImageEncode:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
- def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None):
+ def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
-
- B, H, W, C = image.shape
-
- vae = pipeline["pipe"].vae if vae_override is None else vae_override
- vae.enable_slicing()
- model_name = pipeline.get("model_name", "")
-
- if ("1.5" in model_name or "1_5" in model_name) and image.shape[0] == 1:
- vae_scaling_factor = 1 #/ vae.config.scaling_factor
- else:
- vae_scaling_factor = vae.config.scaling_factor
+
+ try:
+ vae.enable_slicing()
+ except:
+ pass
+
+ vae_scaling_factor = vae.config.scaling_factor
if enable_tiling:
from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
enable_vae_encode_tiling(vae)
- if not pipeline["cpu_offloading"]:
- vae.to(device)
+ vae.to(device)
- check_diffusers_version()
try:
vae._clear_fake_context_parallel_cache()
except:
pass
- input_image = image.clone()
- if mask is not None:
- pipeline["pipe"].original_mask = mask
- # print(mask.shape)
- # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
- # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
- # print(mask.shape)
- # input_image = input_image * (1 -mask)
- else:
- pipeline["pipe"].original_mask = None
- #input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W
- #input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype)
- #input_image = input_image.unsqueeze(2)
+ latents_list = []
+ # Noise augmentation is applied after normalization and conversion to the 5D [B, C, T, H, W] layout that add_noise_to_reference_video expects
+ start_image = (start_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+ if noise_aug_strength > 0:
+ start_image = add_noise_to_reference_video(start_image, ratio=noise_aug_strength)
+ start_latents = vae.encode(start_image).latent_dist.sample(generator)
+ start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
+
+ if end_image is not None:
+ end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
+ if noise_aug_strength > 0:
+ end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
+ end_latents = vae.encode(end_image).latent_dist.sample(generator)
+ end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
+ latents_list = [start_latents, end_latents]
+ final_latents = torch.cat(latents_list, dim=1)
+ else:
+ final_latents = start_latents
+
+ final_latents = final_latents * vae_scaling_factor
+
+ log.info(f"Encoded latents shape: {final_latents.shape}")
+ vae.to(offload_device)
+
+ return ({"samples": final_latents}, )
+
+class CogVideoImageEncodeFunInP:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "vae": ("VAE",),
+ "start_image": ("IMAGE", ),
+ "num_frames": ("INT", {"default": 49, "min": 2, "max": 1024, "step": 1}),
+ },
+ "optional": {
+ "end_image": ("IMAGE", ),
+ "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
+ "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
+ },
+ }
+
+ RETURN_TYPES = ("LATENT",)
+ RETURN_NAMES = ("image_cond_latents",)
+ FUNCTION = "encode"
+ CATEGORY = "CogVideoWrapper"
+
+ def encode(self, vae, start_image, num_frames, end_image=None, enable_tiling=False, noise_aug_strength=0.0):
+
+ device = mm.get_torch_device()
+ offload_device = mm.unet_offload_device()
+ generator = torch.Generator(device=device).manual_seed(0)
+
+ try:
+ vae.enable_slicing()
+ except:
+ pass
+
+ vae_scaling_factor = vae.config.scaling_factor
+
+ if enable_tiling:
+ from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
+ enable_vae_encode_tiling(vae)
+
+ vae.to(device)
+
+ try:
+ vae._clear_fake_context_parallel_cache()
+ except:
+ pass
+
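+ # Fun InP models condition on the full num_frames sequence: the known frame(s) sit at the ends, the unknown frames are filled with padding before VAE encoding, and a per-frame mask marks the known positions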
+ if end_image is not None:
+ # Create a tensor of zeros for padding
+ padding = torch.zeros((num_frames - 2, start_image.shape[1], start_image.shape[2], 3), device=end_image.device, dtype=end_image.dtype) * -1
+ # Concatenate start_image, padding, and end_image
+ input_image = torch.cat([start_image, padding, end_image], dim=0)
+ else:
+ # Create a tensor of zeros for padding
+ padding = torch.zeros((num_frames - 1, start_image.shape[1], start_image.shape[2], 3), device=start_image.device, dtype=start_image.dtype) * -1
+ # Concatenate start_image and padding
+ input_image = torch.cat([start_image, padding], dim=0)
+
input_image = input_image * 2.0 - 1.0
input_image = input_image.to(vae.dtype).to(device)
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
@@ -357,120 +340,34 @@ class CogVideoImageEncode:
B, C, T, H, W = input_image.shape
if noise_aug_strength > 0:
input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength)
+
+ bs = 1
+ new_mask_pixel_values = []
+ print("input_image shape: ",input_image.shape)
+ for i in range(0, input_image.shape[0], bs):
+ mask_pixel_values_bs = input_image[i : i + bs]
+ mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
+ print("mask_pixel_values_bs: ",mask_pixel_values_bs.parameters.shape)
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
+ print("mask_pixel_values_bs: ",mask_pixel_values_bs.shape, mask_pixel_values_bs.min(), mask_pixel_values_bs.max())
+ new_mask_pixel_values.append(mask_pixel_values_bs)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+ masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
- latents_list = []
- # Loop through the temporal dimension in chunks of 16
- for i in range(0, T, chunk_size):
- # Get the chunk of 16 frames (or remaining frames if less than 16 are left)
- end_index = min(i + chunk_size, T)
- image_chunk = input_image[:, :, i:end_index, :, :] # Shape: [B, C, chunk_size, H, W]
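+ # Per-frame conditioning mask in latent space: zeros everywhere, with the first (and, when an end image is given, last) frame position set to the VAE scaling factor, presumably to mark known frames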
+ mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
+ if end_image is not None:
+ mask[:, -1, :, :, :] = vae_scaling_factor
+ mask[:, 0, :, :, :] = vae_scaling_factor
- # Encode the chunk of images
- latents = vae.encode(image_chunk)
-
- sample_mode = "sample"
- if hasattr(latents, "latent_dist") and sample_mode == "sample":
- latents = latents.latent_dist.sample(generator)
- elif hasattr(latents, "latent_dist") and sample_mode == "argmax":
- latents = latents.latent_dist.mode()
- elif hasattr(latents, "latents"):
- latents = latents.latents
-
- latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
- latents_list.append(latents)
-
- # Concatenate all the chunks along the temporal dimension
- final_latents = torch.cat(latents_list, dim=1)
- final_latents = final_latents * vae_scaling_factor
+ final_latents = masked_image_latents * vae_scaling_factor
log.info(f"Encoded latents shape: {final_latents.shape}")
- if not pipeline["cpu_offloading"]:
- vae.to(offload_device)
+ vae.to(offload_device)
- return ({"samples": final_latents}, )
-
-class CogVideoImageInterpolationEncode:
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
- "start_image": ("IMAGE", ),
- "end_image": ("IMAGE", ),
- },
- "optional": {
- "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
- "mask": ("MASK", ),
- "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
-
- },
- }
-
- RETURN_TYPES = ("LATENT",)
- RETURN_NAMES = ("samples",)
- FUNCTION = "encode"
- CATEGORY = "CogVideoWrapper"
-
- def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
- generator = torch.Generator(device=device).manual_seed(0)
-
- B, H, W, C = start_image.shape
-
- vae = pipeline["pipe"].vae if vae_override is None else vae_override
- vae.enable_slicing()
- model_name = pipeline.get("model_name", "")
-
- if ("1.5" in model_name or "1_5" in model_name):
- vae_scaling_factor = 1 / vae.config.scaling_factor
- else:
- vae_scaling_factor = vae.config.scaling_factor
- vae.enable_slicing()
-
- if enable_tiling:
- from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
- enable_vae_encode_tiling(vae)
-
- if not pipeline["cpu_offloading"]:
- vae.to(device)
-
- check_diffusers_version()
- try:
- vae._clear_fake_context_parallel_cache()
- except:
- pass
-
- if mask is not None:
- pipeline["pipe"].original_mask = mask
- # print(mask.shape)
- # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
- # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
- # print(mask.shape)
- # input_image = input_image * (1 -mask)
- else:
- pipeline["pipe"].original_mask = None
-
- start_image = (start_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
- end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
- B, T, C, H, W = start_image.shape
-
- latents_list = []
-
- # Encode the chunk of images
- start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor
- end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor
-
- start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
- end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
- latents_list = [start_latents, end_latents]
-
- # Concatenate all the chunks along the temporal dimension
- final_latents = torch.cat(latents_list, dim=1)
- log.info(f"Encoded latents shape: {final_latents.shape}")
- if not pipeline["cpu_offloading"]:
- vae.to(offload_device)
-
- return ({"samples": final_latents}, )
+ return ({
+ "samples": final_latents,
+ "mask": mask
+ },)
#region Tora
from .tora.traj_utils import process_traj, scale_traj_list_to_256
@@ -480,8 +377,8 @@ class ToraEncodeTrajectory:
@classmethod
def INPUT_TYPES(s):
return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
"tora_model": ("TORAMODEL",),
+ "vae": ("VAE",),
"coordinates": ("STRING", {"forceInput": True}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
@@ -491,7 +388,7 @@ class ToraEncodeTrajectory:
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
"optional": {
- "enable_tiling": ("BOOLEAN", {"default": False}),
+ "enable_tiling": ("BOOLEAN", {"default": True}),
}
}
@@ -500,14 +397,16 @@ class ToraEncodeTrajectory:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
- def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model, enable_tiling=False):
+ def encode(self, vae, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model, enable_tiling=False):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
- vae = pipeline["pipe"].vae
- vae.enable_slicing()
+ try:
+ vae.enable_slicing()
+ except:
+ pass
try:
vae._clear_fake_context_parallel_cache()
except:
@@ -533,33 +432,26 @@ class ToraEncodeTrajectory:
video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device)
video_flow = rearrange(video_flow, "T H W C -> T C H W")
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
-
-
- video_flow = (
- rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
- )
+ video_flow = (rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype))
video_flow_image = rearrange(video_flow, "B C T H W -> (B T) H W C")
- print(video_flow_image.shape)
+ #print(video_flow_image.shape)
mm.soft_empty_cache()
# VAE encode
- if not pipeline["cpu_offloading"]:
- vae.to(device)
-
+ vae.to(device)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
log.info(f"video_flow shape after encoding: {video_flow.shape}") #torch.Size([1, 16, 4, 80, 80])
+ vae.to(offload_device)
- if not pipeline["cpu_offloading"]:
- vae.to(offload_device)
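+ # Move the trajectory extractor to the GPU only for feature extraction, then offload it again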
+ tora_model["traj_extractor"].to(device)
#print("video_flow shape before traj_extractor: ", video_flow.shape) #torch.Size([1, 16, 4, 80, 80])
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
+ tora_model["traj_extractor"].to(offload_device)
video_flow_features = torch.stack(video_flow_features)
#print("video_flow_features after traj_extractor: ", video_flow_features.shape) #torch.Size([42, 4, 128, 40, 40])
video_flow_features = video_flow_features * strength
-
-
tora = {
"video_flow_features" : video_flow_features,
"start_percent" : start_percent,
@@ -574,7 +466,7 @@ class ToraEncodeOpticalFlow:
@classmethod
def INPUT_TYPES(s):
return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
+ "vae": ("VAE",),
"tora_model": ("TORAMODEL",),
"optical_flow": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
@@ -589,15 +481,18 @@ class ToraEncodeOpticalFlow:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
- def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent):
+ def encode(self, vae, optical_flow, strength, tora_model, start_percent, end_percent):
check_diffusers_version()
B, H, W, C = optical_flow.shape
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
- vae = pipeline["pipe"].vae
- vae.enable_slicing()
+ try:
+ vae.enable_slicing()
+ except:
+ pass
+
try:
vae._clear_fake_context_parallel_cache()
except:
@@ -609,15 +504,14 @@ class ToraEncodeOpticalFlow:
mm.soft_empty_cache()
# VAE encode
- if not pipeline["cpu_offloading"]:
- vae.to(device)
+
+ vae.to(device)
video_flow = video_flow.to(vae.dtype).to(vae.device)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
vae.to(offload_device)
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features)
-
video_flow_features = video_flow_features * strength
log.info(f"video_flow shape: {video_flow.shape}")
@@ -632,91 +526,7 @@ class ToraEncodeOpticalFlow:
return (tora, )
-def add_noise_to_reference_video(image, ratio=None):
- if ratio is None:
- sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
- sigma = torch.exp(sigma).to(image.dtype)
- else:
- sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
-
- image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
- image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
- image = image + image_noise
- return image
-class CogVideoControlImageEncode:
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
- "control_video": ("IMAGE", ),
- "base_resolution": ("INT", {"min": 64, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
- "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
- "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
- },
- }
-
- RETURN_TYPES = ("COGCONTROL_LATENTS", "INT", "INT",)
- RETURN_NAMES = ("control_latents", "width", "height")
- FUNCTION = "encode"
- CATEGORY = "CogVideoWrapper"
-
- def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
-
- B, H, W, C = control_video.shape
-
- vae = pipeline["pipe"].vae
- vae.enable_slicing()
-
- if enable_tiling:
- from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
- enable_vae_encode_tiling(vae)
-
- if not pipeline["cpu_offloading"]:
- vae.to(device)
-
- # Count most suitable height and width
- aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
-
- control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
- original_width, original_height = Image.fromarray(control_video[0]).size
-
- closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
- height, width = [int(x / 16) * 16 for x in closest_size]
- log.info(f"Closest bucket size: {width}x{height}")
-
- video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1
- input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width))
-
- control_video = pipeline["pipe"].image_processor.preprocess(rearrange(input_video, "b c f h w -> (b f) c h w"), height=height, width=width)
- control_video = control_video.to(dtype=torch.float32)
- control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
-
- masked_image = control_video.to(device=device, dtype=vae.dtype)
- if noise_aug_strength > 0:
- masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
- bs = 1
- new_mask_pixel_values = []
- for i in range(0, masked_image.shape[0], bs):
- mask_pixel_values_bs = masked_image[i : i + bs]
- mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
- mask_pixel_values_bs = mask_pixel_values_bs.mode()
- new_mask_pixel_values.append(mask_pixel_values_bs)
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
- masked_image_latents = masked_image_latents * vae.config.scaling_factor
-
- vae.to(offload_device)
-
- control_latents = {
- "latents": masked_image_latents,
- "num_frames" : B,
- "height" : height,
- "width" : width,
- }
-
- return (control_latents, width, height)
#region FasterCache
class CogVideoXFasterCache:
@@ -757,12 +567,10 @@ class CogVideoSampler:
def INPUT_TYPES(s):
return {
"required": {
- "pipeline": ("COGVIDEOPIPE",),
+ "model": ("COGVIDEOMODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
- "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 16}),
- "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 16}),
- "num_frames": ("INT", {"default": 49, "min": 17, "max": 1024, "step": 4}),
+ "num_frames": ("INT", {"default": 49, "min": 1, "max": 1024, "step": 1}),
"steps": ("INT", {"default": 50, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
@@ -773,45 +581,54 @@ class CogVideoSampler:
},
"optional": {
"samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
- "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"image_cond_latents": ("LATENT",{"tooltip": "Latent to use for image2video conditioning"} ),
+ "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"context_options": ("COGCONTEXT", ),
"controlnet": ("COGVIDECONTROLNET",),
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS", ),
- #"sigmas": ("SIGMAS", ),
}
}
- RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",)
- RETURN_NAMES = ("cogvideo_pipe", "samples",)
+ RETURN_TYPES = ("LATENT",)
+ RETURN_NAMES = ("samples",)
FUNCTION = "process"
CATEGORY = "CogVideoWrapper"
- def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
+ def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
mm.soft_empty_cache()
- base_path = pipeline["base_path"]
+ pipeline = model
model_name = pipeline.get("model_name", "")
- supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False
+ supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() or "fun" in model_name.lower() else False
- assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
- assert (
- "I2V" not in model_name or
- "1.5" in model_name or
- "1_5" in model_name or
- num_frames == 49 or
- context_options is not None
- ), "1.0 I2V model can only do 49 frames"
+ if "fun" in model_name.lower() and image_cond_latents is not None:
+ assert image_cond_latents["mask"] is not None, "For fun inpaint models use CogVideoImageEncodeFunInP"
+ fun_mask = image_cond_latents["mask"]
+ else:
+ fun_mask = None
+
if image_cond_latents is not None:
assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models"
- # if "I2V" in model_name:
- # assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent"
- # elif "interpolation" in model_name.lower():
- # assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents"
+ image_conds = image_cond_latents["samples"]
+ if "1.5" in model_name or "1_5" in model_name:
+ image_conds = image_conds / 0.7 # needed for 1.5 models
else:
- assert not supports_image_conds, "Image condition latents required for I2V models"
+ if not "fun" in model_name.lower():
+ assert not supports_image_conds, "Image condition latents required for I2V models"
+ image_conds = None
+
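+ # Width/height are no longer node inputs; they are derived from the latent spatial dims (x8, the VAE spatial compression factor) of the init or image-cond latents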
+ if samples is not None:
+ if len(samples["samples"].shape) == 5:
+ B, T, C, H, W = samples["samples"].shape
+ latents = samples["samples"]
+ if len(samples["samples"].shape) == 4:
+ B, C, H, W = samples["samples"].shape
+ latents = None
+ if image_cond_latents is not None:
+ B, T, C, H, W = image_cond_latents["samples"].shape
+ height = H * 8
+ width = W * 8
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -861,9 +678,6 @@ class CogVideoSampler:
cfg = [cfg for _ in range(steps)]
else:
assert len(cfg) == steps, "Length of cfg list must match number of steps"
-
- # if sigmas is not None:
- # sigma_list = sigmas.tolist()
try:
torch.cuda.reset_peak_memory_stats(device)
except:
@@ -878,9 +692,9 @@ class CogVideoSampler:
width = width,
num_frames = num_frames,
guidance_scale=cfg,
- #sigmas=sigma_list if sigmas is not None else None,
- latents=samples["samples"] if samples is not None else None,
- image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
+ latents=latents if samples is not None else None,
+ fun_mask = fun_mask,
+ image_cond_latents=image_conds,
denoise_strength=denoise_strength,
prompt_embeds=positive.to(dtype).to(device),
negative_prompt_embeds=negative.to(dtype).to(device),
@@ -910,7 +724,11 @@ class CogVideoSampler:
except:
pass
- return (pipeline, {"samples": latents})
+ additional_frames = getattr(pipe, "additional_frames", 0)
+ return ({
+ "samples": latents,
+ "additional_frames": additional_frames,
+ },)
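For reference, the sampler now hands downstream nodes a plain ComfyUI LATENT dict (plus the count of padded latent frames) instead of the old (pipeline, latent) pair. A minimal sketch of that contract; the helper and type names below are illustrative, not part of the node API:

from typing import TypedDict
import torch

class CogLatent(TypedDict):
    samples: torch.Tensor      # [B, T_latent, C, H // 8, W // 8]
    additional_frames: int     # latent frames prepended for patch_size_t alignment / vid2vid padding

def make_sampler_output(latents: torch.Tensor, additional_frames: int = 0) -> tuple:
    # Mirrors the return statement above: a 1-tuple holding the LATENT dict,
    # so CogVideoDecode can later drop the padded frames before the VAE decode.
    return ({"samples": latents, "additional_frames": additional_frames},)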
class CogVideoControlNet:
@classmethod
@@ -930,13 +748,7 @@ class CogVideoControlNet:
CATEGORY = "CogVideoWrapper"
def encode(self, controlnet, images, control_strength, control_start_percent, control_end_percent):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
-
- B, H, W, C = images.shape
-
control_frames = images.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1
-
controlnet = {
"control_model": controlnet,
"control_frames": control_frames,
@@ -944,7 +756,6 @@ class CogVideoControlNet:
"control_start": control_start_percent,
"control_end": control_end_percent,
}
-
return (controlnet,)
#region VideoDecode
@@ -952,8 +763,8 @@ class CogVideoDecode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
- "pipeline": ("COGVIDEOPIPE",),
"samples": ("LATENT", ),
+ "vae": ("VAE", {"default": None}),
"enable_vae_tiling": ("BOOLEAN", {"default": True, "tooltip": "Drastically reduces memory use but may introduce seams"}),
},
"optional": {
@@ -962,7 +773,6 @@ class CogVideoDecode:
"tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
- "vae_override": ("VAE", {"default": None}),
}
}
@@ -971,19 +781,20 @@ class CogVideoDecode:
FUNCTION = "decode"
CATEGORY = "CogVideoWrapper"
- def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
- auto_tile_size=True, vae_override=None):
+ def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
+ auto_tile_size=True, pipeline=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
latents = samples["samples"]
- vae = pipeline["pipe"].vae if vae_override is None else vae_override
+
+ additional_frames = samples.get("additional_frames", 0)
- additional_frames = getattr(pipeline["pipe"], "additional_frames", 0)
+ try:
+ vae.enable_slicing()
+ except:
+ pass
- vae.enable_slicing()
-
- if not pipeline["cpu_offloading"]:
- vae.to(device)
+ vae.to(device)
if enable_vae_tiling:
if auto_tile_size:
vae.enable_tiling()
@@ -999,11 +810,11 @@ class CogVideoDecode:
latents = latents.to(vae.dtype).to(device)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / vae.config.scaling_factor * latents
+
try:
vae._clear_fake_context_parallel_cache()
except:
pass
-
try:
frames = vae.decode(latents[:, :, additional_frames:]).sample
except:
@@ -1013,11 +824,13 @@ class CogVideoDecode:
frames = vae.decode(latents[:, :, additional_frames:]).sample
vae.disable_tiling()
- if not pipeline["cpu_offloading"]:
- vae.to(offload_device)
+ vae.to(offload_device)
mm.soft_empty_cache()
- video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
+ video_processor = VideoProcessor(vae_scale_factor=8)
+ video_processor.config.do_resize = False
+
+ video = video_processor.postprocess_video(video=frames, output_type="pt")
video = video[0].permute(0, 2, 3, 1).cpu().float()
return (video,)
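The decode node no longer reaches into a pipeline object for the VAE or the video processor: it takes a VAE input directly, reads additional_frames from the LATENT dict, and builds its own VideoProcessor. A condensed sketch of that flow, assuming a diffusers CogVideoX-style VAE with config.scaling_factor and a decode() that returns an object with .sample:

import torch
from diffusers.video_processor import VideoProcessor

def decode_latents(vae, samples: dict, device: torch.device) -> torch.Tensor:
    latents = samples["samples"]
    additional_frames = samples.get("additional_frames", 0)        # frames prepended by the sampler
    vae.to(device)
    latents = latents.to(vae.dtype).to(device)
    latents = latents.permute(0, 2, 1, 3, 4)                       # -> [B, C, T, H, W]
    latents = latents / vae.config.scaling_factor
    frames = vae.decode(latents[:, :, additional_frames:]).sample  # drop the padding frames
    video_processor = VideoProcessor(vae_scale_factor=8)
    video_processor.config.do_resize = False
    video = video_processor.postprocess_video(video=frames, output_type="pt")
    return video[0].permute(0, 2, 3, 1).cpu().float()              # [T, H, W, C] for ComfyUI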
@@ -1041,6 +854,7 @@ class CogVideoXFunResizeToClosestBucket:
def resize(self, images, base_resolution, upscale_method, crop):
from comfy.utils import common_upscale
+ from .cogvideox_fun.utils import ASPECT_RATIO_512, get_closest_ratio
B, H, W, C = images.shape
# Count most suitable height and width
@@ -1056,256 +870,6 @@ class CogVideoXFunResizeToClosestBucket:
return (resized_images, width, height)
-#region FunSamplers
-class CogVideoXFunSampler:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "pipeline": ("COGVIDEOPIPE",),
- "positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "video_length": ("INT", {"default": 49, "min": 5, "max": 2048, "step": 4}),
- "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
- "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
- "seed": ("INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}),
- "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
- "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
- "scheduler": (available_schedulers, {"default": 'DDIM'})
- },
- "optional":{
- "start_img": ("IMAGE",),
- "end_img": ("IMAGE",),
- "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
- "context_options": ("COGCONTEXT", ),
- "tora_trajectory": ("TORAFEATURES", ),
- "fastercache": ("FASTERCACHEARGS",),
- "vid2vid_images": ("IMAGE",),
- "vid2vid_denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
- },
- }
-
- RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",)
- RETURN_NAMES = ("cogvideo_pipe", "samples",)
- FUNCTION = "process"
- CATEGORY = "CogVideoWrapper"
-
- def process(self, pipeline, positive, negative, video_length, width, height, seed, steps, cfg, scheduler,
- start_img=None, end_img=None, noise_aug_strength=0.0563, context_options=None, fastercache=None,
- tora_trajectory=None, vid2vid_images=None, vid2vid_denoise=1.0):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
- pipe = pipeline["pipe"]
- dtype = pipeline["dtype"]
- base_path = pipeline["base_path"]
- assert "fun" in base_path.lower(), "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
- assert "pose" not in base_path.lower(), "'Pose' models not supported in 'CogVideoXFunSampler', use the 'CogVideoXFunControlSampler'"
-
- mm.soft_empty_cache()
-
- #vid2vid
- if vid2vid_images is not None:
- validation_video = np.array(vid2vid_images.cpu().numpy() * 255, np.uint8)
- #img2vid
- elif start_img is not None:
- start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None
- end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
-
- # Load Sampler
- scheduler_config = pipeline["scheduler_config"]
- if scheduler in scheduler_mapping:
- noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
- pipe.scheduler = noise_scheduler
- else:
- raise ValueError(f"Unknown scheduler: {scheduler}")
-
- if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
- pipe.transformer.to(device)
-
- if context_options is not None:
- context_frames = context_options["context_frames"] // 4
- context_stride = context_options["context_stride"] // 4
- context_overlap = context_options["context_overlap"] // 4
- else:
- context_frames, context_stride, context_overlap = None, None, None
-
- if tora_trajectory is not None:
- pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
-
- if fastercache is not None:
- pipe.transformer.use_fastercache = True
- pipe.transformer.fastercache_counter = 0
- pipe.transformer.fastercache_start_step = fastercache["start_step"]
- pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
- pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
- pipe.transformer.fastercache_device = fastercache["cache_device"]
- pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
- log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
- else:
- pipe.transformer.use_fastercache = False
- pipe.transformer.fastercache_counter = 0
-
- generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
-
- autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
- autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
- with autocast_context:
- video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
- if vid2vid_images is not None:
- input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width))
- else:
- input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))
-
- common_params = {
- "prompt_embeds": positive.to(dtype).to(device),
- "negative_prompt_embeds": negative.to(dtype).to(device),
- "num_frames": video_length,
- "height": height,
- "width": width,
- "generator": generator,
- "guidance_scale": cfg,
- "num_inference_steps": steps,
- "comfyui_progressbar": True,
- "context_schedule":context_options["context_schedule"] if context_options is not None else None,
- "context_frames":context_frames,
- "context_stride": context_stride,
- "context_overlap": context_overlap,
- "freenoise":context_options["freenoise"] if context_options is not None else None,
- "tora":tora_trajectory if tora_trajectory is not None else None,
- }
- latents = pipe(
- **common_params,
- video = input_video,
- mask_video = input_video_mask,
- noise_aug_strength = noise_aug_strength,
- strength = vid2vid_denoise,
- )
- if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
- pipe.transformer.to(offload_device)
- #clear FasterCache
- if fastercache is not None:
- for block in pipe.transformer.transformer_blocks:
- if (hasattr, block, "cached_hidden_states") and block.cached_hidden_states is not None:
- block.cached_hidden_states = None
- block.cached_encoder_hidden_states = None
-
- mm.soft_empty_cache()
-
- return (pipeline, {"samples": latents})
-
-class CogVideoXFunVid2VidSampler:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "note": ("STRING", {"default": "This node is deprecated, functionality moved to 'CogVideoXFunSampler' node instead.", "multiline": True}),
- },
- }
-
- RETURN_TYPES = ()
- FUNCTION = "process"
- CATEGORY = "CogVideoWrapper"
- DEPRECATED = True
- def process(self):
- return ()
-
-class CogVideoXFunControlSampler:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "pipeline": ("COGVIDEOPIPE",),
- "positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "control_latents": ("COGCONTROL_LATENTS",),
- "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
- "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}),
- "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
- "scheduler": (available_schedulers, {"default": 'DDIM'}),
- "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
- "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- },
- "optional": {
- "samples": ("LATENT", ),
- "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- "context_options": ("COGCONTEXT", ),
- },
- }
-
- RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",)
- RETURN_NAMES = ("cogvideo_pipe", "samples",)
- FUNCTION = "process"
- CATEGORY = "CogVideoWrapper"
-
- def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
- control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
- samples=None, denoise_strength=1.0, context_options=None):
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
- pipe = pipeline["pipe"]
- dtype = pipeline["dtype"]
- base_path = pipeline["base_path"]
-
- assert "fun" in base_path.lower(), "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
-
- if not pipeline["cpu_offloading"]:
- pipe.enable_model_cpu_offload(device=device)
-
- mm.soft_empty_cache()
-
- if context_options is not None:
- context_frames = context_options["context_frames"] // 4
- context_stride = context_options["context_stride"] // 4
- context_overlap = context_options["context_overlap"] // 4
- else:
- context_frames, context_stride, context_overlap = None, None, None
-
- # Load Sampler
- scheduler_config = pipeline["scheduler_config"]
- if scheduler in scheduler_mapping:
- noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
- pipe.scheduler = noise_scheduler
- else:
- raise ValueError(f"Unknown scheduler: {scheduler}")
-
- generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
-
- autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
- autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
- with autocast_context:
-
- common_params = {
- "prompt_embeds": positive.to(dtype).to(device),
- "negative_prompt_embeds": negative.to(dtype).to(device),
- "num_frames": control_latents["num_frames"],
- "height": control_latents["height"],
- "width": control_latents["width"],
- "generator": generator,
- "guidance_scale": cfg,
- "num_inference_steps": steps,
- "comfyui_progressbar": True,
- }
-
- latents = pipe(
- **common_params,
- control_video=control_latents["latents"],
- control_strength=control_strength,
- control_start_percent=control_start_percent,
- control_end_percent=control_end_percent,
- scheduler_name=scheduler,
- latents=samples["samples"] if samples is not None else None,
- denoise_strength=denoise_strength,
- context_schedule=context_options["context_schedule"] if context_options is not None else None,
- context_frames=context_frames,
- context_stride= context_stride,
- context_overlap= context_overlap,
- freenoise=context_options["freenoise"] if context_options is not None else None
-
- )
-
- return (pipeline, {"samples": latents})
-
class CogVideoLatentPreview:
@classmethod
def INPUT_TYPES(s):
@@ -1332,9 +896,6 @@ class CogVideoLatentPreview:
latents = samples["samples"].clone()
print("in sample", latents.shape)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
-
- device = mm.get_torch_device()
- offload_device = mm.unet_offload_device()
#[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]]
latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
@@ -1374,15 +935,9 @@ NODE_CLASS_MAPPINGS = {
"CogVideoSampler": CogVideoSampler,
"CogVideoDecode": CogVideoDecode,
"CogVideoTextEncode": CogVideoTextEncode,
- "CogVideoDualTextEncode_311": CogVideoDualTextEncode_311,
"CogVideoImageEncode": CogVideoImageEncode,
- "CogVideoImageInterpolationEncode": CogVideoImageInterpolationEncode,
- "CogVideoXFunSampler": CogVideoXFunSampler,
- "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
- "CogVideoXFunControlSampler": CogVideoXFunControlSampler,
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
- "CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoContextOptions": CogVideoContextOptions,
"CogVideoControlNet": CogVideoControlNet,
"ToraEncodeTrajectory": ToraEncodeTrajectory,
@@ -1390,21 +945,16 @@ NODE_CLASS_MAPPINGS = {
"CogVideoXFasterCache": CogVideoXFasterCache,
"CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket,
"CogVideoLatentPreview": CogVideoLatentPreview,
- "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings
+ "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
+ "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoSampler": "CogVideo Sampler",
"CogVideoDecode": "CogVideo Decode",
"CogVideoTextEncode": "CogVideo TextEncode",
- "CogVideoDualTextEncode_311": "CogVideo DualTextEncode",
"CogVideoImageEncode": "CogVideo ImageEncode",
- "CogVideoImageInterpolationEncode": "CogVideo ImageInterpolation Encode",
- "CogVideoXFunSampler": "CogVideoXFun Sampler",
- "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
- "CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
- "CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoContextOptions": "CogVideo Context Options",
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
@@ -1412,4 +962,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket",
"CogVideoLatentPreview": "CogVideo LatentPreview",
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
+ "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
}
\ No newline at end of file
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 472a308..f869ce4 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -17,15 +17,13 @@ import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
-import torch.nn.functional as F
import math
-from diffusers.models import AutoencoderKLCogVideoX
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
+
#from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.loaders import CogVideoXLoraLoaderMixin
@@ -120,15 +118,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
- vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
- text_encoder ([`T5EncoderModel`]):
- Frozen text-encoder. CogVideoX uses
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
- [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
- tokenizer (`T5Tokenizer`):
- Tokenizer of class
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
@@ -140,31 +129,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def __init__(
self,
- vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
- original_mask = None,
+ dtype: torch.dtype = torch.bfloat16,
+ is_fun_inpaint: bool = False,
):
super().__init__()
- self.register_modules(
- vae=vae, transformer=transformer, scheduler=scheduler
- )
- self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.original_mask = original_mask
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
- self.video_processor.config.do_resize = False
+ self.register_modules(transformer=transformer, scheduler=scheduler)
+ self.vae_scale_factor_spatial = 8
+ self.vae_scale_factor_temporal = 4
+ self.vae_latent_channels = 16
+ self.vae_dtype = dtype
+ self.is_fun_inpaint = is_fun_inpaint
self.input_with_padding = True
def prepare_latents(
- self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
+ self, batch_size, num_channels_latents, num_frames, height, width, device, generator, timesteps, denoise_strength,
num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
):
shape = (
@@ -174,14 +157,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
- noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
+
+ noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae_dtype)
if freenoise:
- print("Applying FreeNoise")
+ logger.info("Applying FreeNoise")
# code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
video_length = num_frames // 4
delta = context_size - context_overlap
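FreeNoise keeps long context-windowed generations temporally consistent by reusing noise across windows instead of sampling every frame independently; the implementation here follows AnimateDiff-Evolved. A rough, simplified illustration of the idea only (not the exact shuffle used in that code):

import torch

def freenoise_like(noise: torch.Tensor, context_size: int, context_overlap: int) -> torch.Tensor:
    # noise: [B, T, C, H, W]. Frames beyond the first window are filled with a
    # shuffled copy of the noise from the frames just before them, so overlapping
    # context windows see correlated rather than independent noise.
    noise = noise.clone()
    video_length = noise.shape[1]
    delta = context_size - context_overlap
    for start in range(context_size, video_length, delta):
        end = min(start + delta, video_length)
        source = torch.arange(start - delta, start)
        perm = source[torch.randperm(delta)][: end - start]
        noise[:, start:end] = noise[:, perm]
    return noise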
@@ -221,20 +200,20 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
latent_timestep = timesteps[:1]
- noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
frames_needed = noise.shape[1]
current_frames = latents.shape[1]
if frames_needed > current_frames:
- repeat_factor = frames_needed // current_frames
+ repeat_factor = frames_needed - current_frames
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
- latents = torch.cat((latents, additional_frame), dim=1)
+ latents = torch.cat((additional_frame, latents), dim=1)
+ self.additional_frames = repeat_factor
elif frames_needed < current_frames:
latents = latents[:, :frames_needed, :, :, :]
- latents = self.scheduler.add_noise(latents, noise, latent_timestep)
+ latents = self.scheduler.add_noise(latents, noise.to(device), latent_timestep)
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
- return latents, timesteps, noise
+ return latents, timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
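The vid2vid path also changes how latent clips are aligned to the requested length: instead of repeating the clip by frames_needed // current_frames and appending, the missing frames_needed - current_frames frames are filled with fresh noise prepended to the clip, and that count is stored as additional_frames so the decoder can trim it again. A compact sketch of the new rule, with illustrative names:

import torch

def align_latent_frames(latents: torch.Tensor, frames_needed: int):
    # latents: [B, T, C, H, W]; returns (aligned_latents, additional_frames)
    current_frames = latents.shape[1]
    if frames_needed > current_frames:
        extra = frames_needed - current_frames
        pad = torch.randn(
            (latents.size(0), extra, latents.size(2), latents.size(3), latents.size(4)),
            dtype=latents.dtype, device=latents.device,
        )
        return torch.cat((pad, latents), dim=1), extra   # prepend noise, remember the count
    if frames_needed < current_frames:
        return latents[:, :frames_needed], 0
    return latents, 0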
@@ -355,10 +334,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
guidance_scale: float = 6,
denoise_strength: float = 1.0,
sigmas: Optional[List[float]] = None,
- num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
+ fun_mask: Optional[torch.Tensor] = None,
image_cond_latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
@@ -398,8 +377,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
- The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -443,7 +420,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- prompt_embeds = prompt_embeds.to(self.vae.dtype)
+ prompt_embeds = prompt_embeds.to(self.vae_dtype)
# 4. Prepare timesteps
if sigmas is None:
@@ -453,7 +430,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
- latent_channels = self.vae.config.latent_channels
+ latent_channels = self.vae_latent_channels
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
@@ -469,18 +446,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
self.additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += self.additional_frames * self.vae_scale_factor_temporal
-
- if self.original_mask is not None:
- image_latents = latents
- original_image_latents = image_latents
-
- latents, timesteps, noise = self.prepare_latents(
- batch_size * num_videos_per_prompt,
+ latents, timesteps = self.prepare_latents(
+ batch_size,
latent_channels,
num_frames,
height,
width,
- self.vae.dtype,
device,
generator,
timesteps,
@@ -491,37 +462,41 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
context_overlap=context_overlap,
freenoise=freenoise,
)
- latents = latents.to(self.vae.dtype)
+ latents = latents.to(self.vae_dtype)
+
+ if self.is_fun_inpaint and fun_mask is None: # For FUN inpaint vid2vid, we need to mask all the latents
+ fun_mask = torch.zeros_like(latents[:, :, :1, :, :], device=latents.device, dtype=latents.dtype)
+ fun_masked_video_latents = torch.zeros_like(latents, device=latents.device, dtype=latents.dtype)
# 5.5.
if image_cond_latents is not None:
- if image_cond_latents.shape[1] > 1:
+ if image_cond_latents.shape[1] == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
- batch_size,
- (latents.shape[1] - 2),
- self.vae.config.latent_channels,
- height // self.vae_scale_factor_spatial,
- width // self.vae_scale_factor_spatial,
+ batch_size,
+ (latents.shape[1] - 2),
+ self.vae_latent_channels,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
)
- latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
if self.transformer.config.patch_size_t is not None:
- first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
- image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
+ first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
+ image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
- else:
+ elif image_cond_latents.shape[1] == 1:
logger.info("Only one image conditioning frame received, img2vid")
if self.input_with_padding:
padding_shape = (
batch_size,
(latents.shape[1] - 1),
- self.vae.config.latent_channels,
+ self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
- latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if self.transformer.config.patch_size_t is not None:
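Both image-conditioning branches build a full-length conditioning tensor by surrounding the encoded frame latents with zeros: img2vid pads after the single start frame, interpolation pads between the start and end frames. A simplified sketch (helper name and shapes are illustrative):

import torch

def pad_image_cond_latents(image_cond_latents: torch.Tensor, latent_frames: int) -> torch.Tensor:
    # image_cond_latents: [B, 1, C, h, w] for img2vid or [B, 2, C, h, w] for interpolation
    B, F, C, h, w = image_cond_latents.shape
    def zeros(n: int) -> torch.Tensor:
        return torch.zeros(B, n, C, h, w, device=image_cond_latents.device,
                           dtype=image_cond_latents.dtype)
    if F == 1:   # img2vid: start frame, then zero padding
        return torch.cat([image_cond_latents, zeros(latent_frames - 1)], dim=1)
    if F == 2:   # interpolation: start frame, zero padding, end frame
        return torch.cat([image_cond_latents[:, :1], zeros(latent_frames - 2),
                          image_cond_latents[:, 1:]], dim=1)
    return image_cond_latents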
@@ -529,22 +504,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
else:
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
+ else:
+ logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
- # masks
- if self.original_mask is not None:
- mask = self.original_mask.to(device)
- logger.info(f"self.original_mask: {self.original_mask.shape}")
-
- mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
- if mask.shape[0] != latents.shape[1]:
- mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
- else:
- mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
- logger.info(f"latents: {latents.shape}")
- logger.info(f"mask: {mask.shape}")
-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -554,7 +518,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
raise NotImplementedError("Context schedule not currently supported with image conditioning")
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_context_schedule = True
- from .cogvideox_fun.context import get_context_scheduler
+ from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
#todo ofs embeds?
@@ -747,7 +711,18 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if image_cond_latents is not None:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
- latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+ if fun_mask is not None: #for fun img2vid and interpolation
+ fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
+ masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
+ latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
+ else:
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+ else: # for Fun inpaint vid2vid
+ if fun_mask is not None:
+ fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
+ fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
+ fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
+ latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
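In every branch above the extra conditioning is concatenated onto the transformer input along the channel dimension; for Fun inpaint models the single-channel mask rides along with the (possibly all-zero) masked video latents. A minimal sketch of that channel layout, assuming [B, T, C, H, W] latents:

import torch

def build_fun_inpaint_input(latents: torch.Tensor, fun_mask: torch.Tensor,
                            masked_video_latents: torch.Tensor, do_cfg: bool) -> torch.Tensor:
    # latents: [B, T, C, H, W]; fun_mask: [B, T, 1, H, W]; masked_video_latents: same shape as latents
    latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
    mask = torch.cat([fun_mask] * 2) if do_cfg else fun_mask
    masked = torch.cat([masked_video_latents] * 2) if do_cfg else masked_video_latents
    extra = torch.cat([mask, masked], dim=2).to(latents.dtype)  # mask channel + masked video channels
    return torch.cat([latent_model_input, extra], dim=2)        # append along the channel dim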
@@ -767,9 +742,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
return_dict=False,
)[0]
if isinstance(controlnet_states, (tuple, list)):
- controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+ controlnet_states = [x.to(dtype=self.vae_dtype) for x in controlnet_states]
else:
- controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+ controlnet_states = controlnet_states.to(dtype=self.vae_dtype)
# predict noise model_output
@@ -796,30 +771,18 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
- latents = self.scheduler.step(noise_pred, t, latents.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
+ latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
- latents.to(self.vae.dtype),
+ latents.to(self.vae_dtype),
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
- # start diff diff
- if i < len(timesteps) - 1 and self.original_mask is not None:
- noise_timestep = timesteps[i + 1]
- image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
- )
- mask = mask.to(latents)
- ts_from = timesteps[0]
- ts_to = timesteps[-1]
- threshold = (t - ts_to) / (ts_from - ts_to)
- mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
- latents = image_latent * mask + latents * (1 - mask)
- # end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
diff --git a/pyproject.toml b/pyproject.toml
index 78b9ed8..1ca05f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,9 @@
[project]
name = "comfyui-cogvideoxwrapper"
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
-version = "1.1.0"
+version = "1.5.0"
license = {file = "LICENSE"}
-dependencies = ["huggingface_hub", "diffusers>=0.30.1", "accelerate>=0.33.0"]
+dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
[project.urls]
Repository = "https://github.com/kijai/ComfyUI-CogVideoXWrapper"