From a7646c0d6ff3a147299082a5fc6050af326bde02 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 19 Nov 2024 03:04:22 +0200 Subject: [PATCH] refactor - unify all pipelines into one - unify transformer model into one - separate VAE - add single file model loading --- .gitignore | 3 +- cogvideox_fun/autoencoder_magvit.py | 1296 ------------------- cogvideox_fun/pipeline_cogvideox_control.py | 866 ------------- cogvideox_fun/pipeline_cogvideox_inpaint.py | 1037 --------------- cogvideox_fun/transformer_3d.py | 823 ------------ cogvideox_fun/utils.py | 138 +- cogvideox_fun/context.py => context.py | 0 convert_weight_sat2hf.py | 303 ----- custom_cogvideox_transformer_3d.py | 1 - model_loading.py | 432 +++++-- nodes.py | 879 +++---------- pipeline_cogvideox.py | 157 +-- pyproject.toml | 4 +- 13 files changed, 594 insertions(+), 5345 deletions(-) delete mode 100644 cogvideox_fun/autoencoder_magvit.py delete mode 100644 cogvideox_fun/pipeline_cogvideox_control.py delete mode 100644 cogvideox_fun/pipeline_cogvideox_inpaint.py delete mode 100644 cogvideox_fun/transformer_3d.py rename cogvideox_fun/context.py => context.py (100%) delete mode 100644 convert_weight_sat2hf.py diff --git a/.gitignore b/.gitignore index da24bdf..d75870d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ master_ip logs/ *.DS_Store .idea -*.pt \ No newline at end of file +*.pt +tools/ \ No newline at end of file diff --git a/cogvideox_fun/autoencoder_magvit.py b/cogvideox_fun/autoencoder_magvit.py deleted file mode 100644 index 9c2b906..0000000 --- a/cogvideox_fun/autoencoder_magvit.py +++ /dev/null @@ -1,1296 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders.single_file_model import FromOriginalModelMixin -from diffusers.utils import logging -from diffusers.utils.accelerate_utils import apply_forward_hook -from diffusers.models.activations import get_activation -from diffusers.models.downsampling import CogVideoXDownsample3D -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.upsampling import CogVideoXUpsample3D -from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class CogVideoXSafeConv3d(nn.Conv3d): - r""" - A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. 
- """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 - - # Set to 2GB, suitable for CuDNN - if memory_count > 2: - kernel_size = self.kernel_size[0] - part_num = int(memory_count / 2) + 1 - input_chunks = torch.chunk(input, part_num, dim=2) - - if kernel_size > 1: - input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) - for i in range(1, len(input_chunks)) - ] - - output_chunks = [] - for input_chunk in input_chunks: - output_chunks.append(super().forward(input_chunk)) - output = torch.cat(output_chunks, dim=2) - return output - else: - return super().forward(input) - - -class CogVideoXCausalConv3d(nn.Module): - r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. - - Args: - in_channels (`int`): Number of channels in the input tensor. - out_channels (`int`): Number of output channels produced by the convolution. - kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. - stride (`int`, defaults to `1`): Stride of the convolution. - dilation (`int`, defaults to `1`): Dilation rate of the convolution. - pad_mode (`str`, defaults to `"constant"`): Padding mode. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - stride: int = 1, - dilation: int = 1, - pad_mode: str = "constant", - ): - super().__init__() - - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) * 3 - - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 - - self.height_pad = height_pad - self.width_pad = width_pad - self.time_pad = time_pad - self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - - self.temporal_dim = 2 - self.time_kernel_size = time_kernel_size - - stride = (stride, 1, 1) - dilation = (dilation, 1, 1) - self.conv = CogVideoXSafeConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - ) - - self.conv_cache = None - - def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: - kernel_size = self.time_kernel_size - if kernel_size > 1: - cached_inputs = ( - [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - ) - inputs = torch.cat(cached_inputs + [inputs], dim=2) - return inputs - - def _clear_fake_context_parallel_cache(self): - del self.conv_cache - self.conv_cache = None - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = self.fake_context_parallel_forward(inputs) - - self._clear_fake_context_parallel_cache() - # Note: we could move these to the cpu for a lower maximum memory usage but its only a few - # hundred megabytes and so let's not do it for now - self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs = F.pad(inputs, padding_2d, mode="constant", value=0) - - output = self.conv(inputs) - return output - - -class CogVideoXSpatialNorm3D(nn.Module): - r""" - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific - to 3D-video like data. 
- - CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. - - Args: - f_channels (`int`): - The number of channels for input to group normalization layer, and output of the spatial norm layer. - zq_channels (`int`): - The number of channels for the quantized vector as described in the paper. - groups (`int`): - Number of groups to separate the channels into for group normalization. - """ - - def __init__( - self, - f_channels: int, - zq_channels: int, - groups: int = 32, - ): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) - self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - if f.shape[2] > 1 and f.shape[2] % 2 == 1: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = F.interpolate(z_first, size=f_first_size) - z_rest = F.interpolate(z_rest, size=f_rest_size) - zq = torch.cat([z_first, z_rest], dim=2) - else: - zq = F.interpolate(zq, size=f.shape[-3:]) - - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - - -class CogVideoXResnetBlock3D(nn.Module): - r""" - A 3D ResNet block used in the CogVideoX model. - - Args: - in_channels (`int`): - Number of input channels. - out_channels (`int`, *optional*): - Number of output channels. If None, defaults to `in_channels`. - dropout (`float`, defaults to `0.0`): - Dropout rate. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. - groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - non_linearity (`str`, defaults to `"swish"`): - Activation function to use. - conv_shortcut (bool, defaults to `False`): - Whether or not to use a convolution shortcut. - spatial_norm_dim (`int`, *optional*): - The dimension to use for spatial norm if it is to be used instead of group norm. - pad_mode (str, defaults to `"first"`): - Padding mode. 
- """ - - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - eps: float = 1e-6, - non_linearity: str = "swish", - conv_shortcut: bool = False, - spatial_norm_dim: Optional[int] = None, - pad_mode: str = "first", - ): - super().__init__() - - out_channels = out_channels or in_channels - - self.in_channels = in_channels - self.out_channels = out_channels - self.nonlinearity = get_activation(non_linearity) - self.use_conv_shortcut = conv_shortcut - - if spatial_norm_dim is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = CogVideoXSpatialNorm3D( - f_channels=in_channels, - zq_channels=spatial_norm_dim, - groups=groups, - ) - self.norm2 = CogVideoXSpatialNorm3D( - f_channels=out_channels, - zq_channels=spatial_norm_dim, - groups=groups, - ) - - self.conv1 = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - - if temb_channels > 0: - self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) - - self.dropout = nn.Dropout(dropout) - self.conv2 = CogVideoXCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = CogVideoXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode - ) - else: - self.conv_shortcut = CogVideoXSafeConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward( - self, - inputs: torch.Tensor, - temb: Optional[torch.Tensor] = None, - zq: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - hidden_states = inputs - - if zq is not None: - hidden_states = self.norm1(hidden_states, zq) - else: - hidden_states = self.norm1(hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) - - if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] - - if zq is not None: - hidden_states = self.norm2(hidden_states, zq) - else: - hidden_states = self.norm2(hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.in_channels != self.out_channels: - inputs = self.conv_shortcut(inputs) - - hidden_states = hidden_states + inputs - return hidden_states - - -class CogVideoXDownBlock3D(nn.Module): - r""" - A downsampling block used in the CogVideoX model. - - Args: - in_channels (`int`): - Number of input channels. - out_channels (`int`, *optional*): - Number of output channels. If None, defaults to `in_channels`. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. - num_layers (`int`, defaults to `1`): - Number of resnet layers. - dropout (`float`, defaults to `0.0`): - Dropout rate. - resnet_eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - resnet_act_fn (`str`, defaults to `"swish"`): - Activation function to use. - resnet_groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - add_downsample (`bool`, defaults to `True`): - Whether or not to use a downsampling layer. 
If not used, output dimension would be same as input dimension. - compress_time (`bool`, defaults to `False`): - Whether or not to downsample across temporal dimension. - pad_mode (str, defaults to `"first"`): - Padding mode. - """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - add_downsample: bool = True, - downsample_padding: int = 0, - compress_time: bool = False, - pad_mode: str = "first", - ): - super().__init__() - - resnets = [] - for i in range(num_layers): - in_channel = in_channels if i == 0 else out_channels - resnets.append( - CogVideoXResnetBlock3D( - in_channels=in_channel, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - pad_mode=pad_mode, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.downsamplers = None - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - CogVideoXDownsample3D( - out_channels, out_channels, padding=downsample_padding, compress_time=compress_time - ) - ] - ) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - zq: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq - ) - else: - hidden_states = resnet(hidden_states, temb, zq) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class CogVideoXMidBlock3D(nn.Module): - r""" - A middle block used in the CogVideoX model. - - Args: - in_channels (`int`): - Number of input channels. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. - dropout (`float`, defaults to `0.0`): - Dropout rate. - num_layers (`int`, defaults to `1`): - Number of resnet layers. - resnet_eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - resnet_act_fn (`str`, defaults to `"swish"`): - Activation function to use. - resnet_groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - spatial_norm_dim (`int`, *optional*): - The dimension to use for spatial norm if it is to be used instead of group norm. - pad_mode (str, defaults to `"first"`): - Padding mode. 
- """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - spatial_norm_dim: Optional[int] = None, - pad_mode: str = "first", - ): - super().__init__() - - resnets = [] - for _ in range(num_layers): - resnets.append( - CogVideoXResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=resnet_groups, - eps=resnet_eps, - spatial_norm_dim=spatial_norm_dim, - non_linearity=resnet_act_fn, - pad_mode=pad_mode, - ) - ) - self.resnets = nn.ModuleList(resnets) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - zq: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq - ) - else: - hidden_states = resnet(hidden_states, temb, zq) - - return hidden_states - - -class CogVideoXUpBlock3D(nn.Module): - r""" - An upsampling block used in the CogVideoX model. - - Args: - in_channels (`int`): - Number of input channels. - out_channels (`int`, *optional*): - Number of output channels. If None, defaults to `in_channels`. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. - dropout (`float`, defaults to `0.0`): - Dropout rate. - num_layers (`int`, defaults to `1`): - Number of resnet layers. - resnet_eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - resnet_act_fn (`str`, defaults to `"swish"`): - Activation function to use. - resnet_groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - spatial_norm_dim (`int`, defaults to `16`): - The dimension to use for spatial norm if it is to be used instead of group norm. - add_upsample (`bool`, defaults to `True`): - Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. - compress_time (`bool`, defaults to `False`): - Whether or not to downsample across temporal dimension. - pad_mode (str, defaults to `"first"`): - Padding mode. 
- """ - - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - spatial_norm_dim: int = 16, - add_upsample: bool = True, - upsample_padding: int = 1, - compress_time: bool = False, - pad_mode: str = "first", - ): - super().__init__() - - resnets = [] - for i in range(num_layers): - in_channel = in_channels if i == 0 else out_channels - resnets.append( - CogVideoXResnetBlock3D( - in_channels=in_channel, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - spatial_norm_dim=spatial_norm_dim, - pad_mode=pad_mode, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.upsamplers = None - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - CogVideoXUpsample3D( - out_channels, out_channels, padding=upsample_padding, compress_time=compress_time - ) - ] - ) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - zq: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r"""Forward method of the `CogVideoXUpBlock3D` class.""" - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq - ) - else: - hidden_states = resnet(hidden_states, temb, zq) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class CogVideoXEncoder3D(nn.Module): - r""" - The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available - options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 16, - down_block_types: Tuple[str, ...] = ( - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - ), - block_out_channels: Tuple[int, ...] 
= (128, 256, 256, 512), - layers_per_block: int = 3, - act_fn: str = "silu", - norm_eps: float = 1e-6, - norm_num_groups: int = 32, - dropout: float = 0.0, - pad_mode: str = "first", - temporal_compression_ratio: float = 4, - ): - super().__init__() - - # log2 of temporal_compress_times - temporal_compress_level = int(np.log2(temporal_compression_ratio)) - - self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - self.down_blocks = nn.ModuleList([]) - - # down blocks - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - compress_time = i < temporal_compress_level - - if down_block_type == "CogVideoXDownBlock3D": - down_block = CogVideoXDownBlock3D( - in_channels=input_channel, - out_channels=output_channel, - temb_channels=0, - dropout=dropout, - num_layers=layers_per_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - add_downsample=not is_final_block, - compress_time=compress_time, - ) - else: - raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") - - self.down_blocks.append(down_block) - - # mid block - self.mid_block = CogVideoXMidBlock3D( - in_channels=block_out_channels[-1], - temb_channels=0, - dropout=dropout, - num_layers=2, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - pad_mode=pad_mode, - ) - - self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = CogVideoXCausalConv3d( - block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - r"""The forward method of the `CogVideoXEncoder3D` class.""" - hidden_states = self.conv_in(sample) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # 1. Down - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, temb, None - ) - - # 2. Mid - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, None - ) - else: - # 1. Down - for down_block in self.down_blocks: - hidden_states = down_block(hidden_states, temb, None) - - # 2. Mid - hidden_states = self.mid_block(hidden_states, temb, None) - - # 3. Post-process - hidden_states = self.norm_out(hidden_states) - hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) - return hidden_states - - -class CogVideoXDecoder3D(nn.Module): - r""" - The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output - sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. 
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - in_channels: int = 16, - out_channels: int = 3, - up_block_types: Tuple[str, ...] = ( - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - ), - block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), - layers_per_block: int = 3, - act_fn: str = "silu", - norm_eps: float = 1e-6, - norm_num_groups: int = 32, - dropout: float = 0.0, - pad_mode: str = "first", - temporal_compression_ratio: float = 4, - ): - super().__init__() - - reversed_block_out_channels = list(reversed(block_out_channels)) - - self.conv_in = CogVideoXCausalConv3d( - in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode - ) - - # mid block - self.mid_block = CogVideoXMidBlock3D( - in_channels=reversed_block_out_channels[0], - temb_channels=0, - num_layers=2, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - spatial_norm_dim=in_channels, - pad_mode=pad_mode, - ) - - # up blocks - self.up_blocks = nn.ModuleList([]) - - output_channel = reversed_block_out_channels[0] - temporal_compress_level = int(np.log2(temporal_compression_ratio)) - - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - compress_time = i < temporal_compress_level - - if up_block_type == "CogVideoXUpBlock3D": - up_block = CogVideoXUpBlock3D( - in_channels=prev_output_channel, - out_channels=output_channel, - temb_channels=0, - dropout=dropout, - num_layers=layers_per_block + 1, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - spatial_norm_dim=in_channels, - add_upsample=not is_final_block, - compress_time=compress_time, - pad_mode=pad_mode, - ) - prev_output_channel = output_channel - else: - raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") - - self.up_blocks.append(up_block) - - self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) - self.conv_act = nn.SiLU() - self.conv_out = CogVideoXCausalConv3d( - reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - r"""The forward method of the `CogVideoXDecoder3D` class.""" - hidden_states = self.conv_in(sample) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # 1. Mid - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, sample - ) - - # 2. Up - for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, temb, sample - ) - else: - # 1. 
Mid - hidden_states = self.mid_block(hidden_states, temb, sample) - - # 2. Up - for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb, sample) - - # 3. Post-process - hidden_states = self.norm_out(hidden_states, sample) - hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) - return hidden_states - - -class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): - r""" - A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in - [CogVideoX](https://github.com/THUDM/CogVideo). - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): - Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - sample_size (`int`, *optional*, defaults to `32`): Sample input size. - scaling_factor (`float`, *optional*, defaults to `1.15258426`): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - force_upcast (`bool`, *optional*, default to `True`): - If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. 
VAE - can be fine-tuned / trained to a lower range without loosing too much precision in which case - `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["CogVideoXResnetBlock3D"] - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ( - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - "CogVideoXDownBlock3D", - ), - up_block_types: Tuple[str] = ( - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - "CogVideoXUpBlock3D", - ), - block_out_channels: Tuple[int] = (128, 256, 256, 512), - latent_channels: int = 16, - layers_per_block: int = 3, - act_fn: str = "silu", - norm_eps: float = 1e-6, - norm_num_groups: int = 32, - temporal_compression_ratio: float = 4, - sample_height: int = 480, - sample_width: int = 720, - scaling_factor: float = 1.15258426, - shift_factor: Optional[float] = None, - latents_mean: Optional[Tuple[float]] = None, - latents_std: Optional[Tuple[float]] = None, - force_upcast: float = True, - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - ): - super().__init__() - - self.encoder = CogVideoXEncoder3D( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_eps=norm_eps, - norm_num_groups=norm_num_groups, - temporal_compression_ratio=temporal_compression_ratio, - ) - self.decoder = CogVideoXDecoder3D( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_eps=norm_eps, - norm_num_groups=norm_num_groups, - temporal_compression_ratio=temporal_compression_ratio, - ) - self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None - self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None - - self.use_slicing = False - self.use_tiling = False - - # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not - # recommended because the temporal parts of the VAE, here, are tricky to understand. - # If you decode X latent frames together, the number of output frames is: - # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames - # - # Example with num_latent_frames_batch_size = 2: - # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together - # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) - # => 6 * 8 = 48 frames - # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together - # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + - # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) - # => 1 * 9 + 5 * 8 = 49 frames - # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. 
Note that - # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different - # number of temporal frames. - self.num_latent_frames_batch_size = 2 - - # We make the minimum height and width of sample for tiling half that of the generally supported - self.tile_sample_min_height = sample_height // 2 - self.tile_sample_min_width = sample_width // 2 - self.tile_latent_min_height = int( - self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) - ) - self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) - - # These are experimental overlap factors that were chosen based on experimentation and seem to work best for - # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX - # and so the tiling implementation has only been tested on those specific resolutions. - self.tile_overlap_factor_height = 1 / 6 - self.tile_overlap_factor_width = 1 / 5 - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): - module.gradient_checkpointing = value - - def _clear_fake_context_parallel_cache(self): - for name, module in self.named_modules(): - if isinstance(module, CogVideoXCausalConv3d): - logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") - module._clear_fake_context_parallel_cache() - - def enable_tiling( - self, - tile_sample_min_height: Optional[int] = None, - tile_sample_min_width: Optional[int] = None, - tile_overlap_factor_height: Optional[float] = None, - tile_overlap_factor_width: Optional[float] = None, - ) -> None: - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - - Args: - tile_sample_min_height (`int`, *optional*): - The minimum height required for a sample to be separated into tiles across the height dimension. - tile_sample_min_width (`int`, *optional*): - The minimum width required for a sample to be separated into tiles across the width dimension. - tile_overlap_factor_height (`int`, *optional*): - The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are - no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher - value might cause more tiles to be processed leading to slow down of the decoding process. - tile_overlap_factor_width (`int`, *optional*): - The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there - are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher - value might cause more tiles to be processed leading to slow down of the decoding process. 
- """ - self.use_tiling = True - self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height - self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width - self.tile_latent_min_height = int( - self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) - ) - self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) - self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height - self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width - - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - @apply_forward_hook - def encode( - self, x: torch.Tensor, return_dict: bool = True - ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: - """ - Encode a batch of images into latents. - - Args: - x (`torch.Tensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. - - Returns: - The latent representations of the encoded images. If `return_dict` is True, a - [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
- """ - batch_size, num_channels, num_frames, height, width = x.shape - if num_frames == 1: - h = self.encoder(x) - if self.quant_conv is not None: - h = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(h) - else: - frame_batch_size = 4 - h = [] - for i in range(num_frames // frame_batch_size): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) - end_frame = frame_batch_size * (i + 1) + remaining_frames - z_intermediate = x[:, :, start_frame:end_frame] - z_intermediate = self.encoder(z_intermediate) - if self.quant_conv is not None: - z_intermediate = self.quant_conv(z_intermediate) - h.append(z_intermediate) - self._clear_fake_context_parallel_cache() - h = torch.cat(h, dim=2) - posterior = DiagonalGaussianDistribution(h) - if not return_dict: - return (posterior,) - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - batch_size, num_channels, num_frames, height, width = z.shape - - if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): - return self.tiled_decode(z, return_dict=return_dict) - - if num_frames == 1: - dec = [] - z_intermediate = z - if self.post_quant_conv is not None: - z_intermediate = self.post_quant_conv(z_intermediate) - z_intermediate = self.decoder(z_intermediate) - dec.append(z_intermediate) - else: - frame_batch_size = self.num_latent_frames_batch_size - dec = [] - for i in range(num_frames // frame_batch_size): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) - end_frame = frame_batch_size * (i + 1) + remaining_frames - z_intermediate = z[:, :, start_frame:end_frame] - if self.post_quant_conv is not None: - z_intermediate = self.post_quant_conv(z_intermediate) - z_intermediate = self.decoder(z_intermediate) - dec.append(z_intermediate) - - self._clear_fake_context_parallel_cache() - dec = torch.cat(dec, dim=2) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - @apply_forward_hook - def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - """ - Decode a batch of images. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. 
- """ - if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] - decoded = torch.cat(decoded_slices) - else: - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - return DecoderOutput(sample=decoded) - - def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: - blend_extent = min(a.shape[3], b.shape[3], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( - y / blend_extent - ) - return b - - def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: - blend_extent = min(a.shape[4], b.shape[4], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( - x / blend_extent - ) - return b - - def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - r""" - Decode a batch of images using a tiled decoder. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - """ - # Rough memory assessment: - # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. - # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. - # - Assume fp16 (2 bytes per value). - # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB - # - # Memory assessment when using tiling: - # - Assume everything as above but now HxW is 240x360 by tiling in half - # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB - - batch_size, num_channels, num_frames, height, width = z.shape - - overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) - overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) - blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) - blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) - row_limit_height = self.tile_sample_min_height - blend_extent_height - row_limit_width = self.tile_sample_min_width - blend_extent_width - frame_batch_size = self.num_latent_frames_batch_size - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. 
- rows = [] - for i in range(0, height, overlap_height): - row = [] - for j in range(0, width, overlap_width): - time = [] - for k in range(num_frames // frame_batch_size): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) - end_frame = frame_batch_size * (k + 1) + remaining_frames - tile = z[ - :, - :, - start_frame:end_frame, - i : i + self.tile_latent_min_height, - j : j + self.tile_latent_min_width, - ] - if self.post_quant_conv is not None: - tile = self.post_quant_conv(tile) - tile = self.decoder(tile) - time.append(tile) - self._clear_fake_context_parallel_cache() - row.append(torch.cat(time, dim=2)) - rows.append(row) - - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent_width) - result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[torch.Tensor, torch.Tensor]: - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z) - if not return_dict: - return (dec,) - return dec diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py deleted file mode 100644 index f598147..0000000 --- a/cogvideox_fun/pipeline_cogvideox_control.py +++ /dev/null @@ -1,866 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -import math -from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from einops import rearrange - -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from diffusers.utils import BaseOutput, logging, replace_example_docstring -from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor -from diffusers.image_processor import VaeImageProcessor -from einops import rearrange - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```python - >>> import torch - >>> from diffusers import CogVideoX_Fun_Pipeline - >>> from diffusers.utils import export_to_video - - >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" - >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") - >>> prompt = ( - ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " - ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " - ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " - ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " - ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " - ... "atmosphere of this unique musical performance." - ... ) - >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] - >>> export_to_video(video, "output.mp4", fps=8) - ``` -""" - - -# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -@dataclass -class CogVideoX_Fun_PipelineOutput(BaseOutput): - r""" - Output class for CogVideo pipelines. - - Args: - video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing - denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape - `(batch_size, num_frames, channels, height, width)`. - """ - - videos: torch.Tensor - - -class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline): - r""" - Pipeline for text-to-video generation using CogVideoX. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - transformer ([`CogVideoXTransformer3DModel`]): - A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
- """ - - _optional_components = [] - model_cpu_offload_seq = "vae->transformer->vae" - - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - ] - - def __init__( - self, - vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3DModel, - scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - ): - super().__init__() - - self.register_modules( - vae=vae, transformer=transformer, scheduler=scheduler - ) - self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True - ) - - def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, - latents=None, freenoise=True, context_size=None, context_overlap=None - ): - shape = ( - batch_size, - (num_frames - 1) // self.vae_scale_factor_temporal + 1, - num_channels_latents, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype) - if freenoise: - print("Applying FreeNoise") - # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) - video_length = num_frames // 4 - delta = context_size - context_overlap - for start_idx in range(0, video_length-context_size, delta): - # start_idx corresponds to the beginning of a context window - # goal: place shuffled in the delta region right after the end of the context window - # if space after context window is not enough to place the noise, adjust and finish - place_idx = start_idx + context_size - # if place_idx is outside the valid indexes, we are already finished - if place_idx >= video_length: - break - end_idx = place_idx - 1 - #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta) - - # if there is not enough room to copy delta amount of indexes, copy limited amount and finish - if end_idx + delta >= video_length: - final_delta = video_length - place_idx - # generate list of indexes in final delta region - list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) - # shuffle list - list_idx = list_idx[torch.randperm(final_delta, generator=generator)] - # apply shuffled indexes - noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :] - break - # otherwise, do normal behavior - # generate list of indexes in delta region - list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) - # shuffle list - list_idx = list_idx[torch.randperm(delta, generator=generator)] - # apply shuffled indexes - #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx) - noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :] - if latents is None: - latents = noise.to(device) - else: - latents = latents.to(device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device) - latent_timestep = timesteps[:1] - - noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype) - frames_needed = noise.shape[1] - current_frames = latents.shape[1] - - if frames_needed > current_frames: - repeat_factor = frames_needed // current_frames - additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device) - latents = torch.cat((latents, additional_frame), dim=1) - elif frames_needed < current_frames: - latents = latents[:, :frames_needed, :, :, :] - - latents = self.scheduler.add_noise(latents, noise, latent_timestep) - latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler - return latents, timesteps, noise - - def prepare_control_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - - if mask is not None: - mask = mask.to(device=device, dtype=self.vae.dtype) - bs = 1 - new_mask = [] - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = 
mask_bs.mode() - new_mask.append(mask_bs) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor - - if masked_image is not None: - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) - bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.mode() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor - else: - masked_image_latents = None - - return mask, masked_image_latents - - def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - latents = 1 / self.vae.config.scaling_factor * latents - - frames = self.vae.decode(latents).sample - frames = (frames / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - frames = frames.cpu().float().numpy() - return frames - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def fuse_qkv_projections(self) -> None: - r"""Enables fused QKV projections.""" - self.fusing_transformer = True - self.transformer.fuse_qkv_projections() - - def unfuse_qkv_projections(self) -> None: - r"""Disable QKV projection fusion if enabled.""" - if not self.fusing_transformer: - logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") - else: - self.transformer.unfuse_qkv_projections() - self.fusing_transformer = False - - def _prepare_rotary_positional_embeddings( - self, - height: int, - width: int, - num_frames: int, - device: torch.device, - start_frame: Optional[int] = None, - end_frame: Optional[int] = None, - context_frames: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=num_frames, - use_real=True, - ) - - if start_frame is not None or context_frames is not None: - freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1) - freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1) - if context_frames is not None: - freqs_cos = freqs_cos[context_frames] - freqs_sin = freqs_sin[context_frames] - else: - freqs_cos = freqs_cos[start_frame:end_frame] - freqs_sin = freqs_sin[start_frame:end_frame] - - freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1]) - freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1]) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def 
get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - video: Union[torch.FloatTensor] = None, - control_video: Union[torch.FloatTensor] = None, - num_frames: int = 49, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - guidance_scale: float = 6, - use_dynamic_cfg: bool = False, - denoise_strength: float = 1.0, - num_videos_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 226, - comfyui_progressbar: bool = False, - control_strength: float = 1.0, - control_start_percent: float = 0.0, - control_end_percent: float = 1.0, - scheduler_name: str = "DPM", - context_schedule: Optional[str] = None, - context_frames: Optional[int] = None, - context_stride: Optional[int] = None, - context_overlap: Optional[int] = None, - freenoise: Optional[bool] = True, - tora: Optional[dict] = None, - ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_frames (`int`, defaults to `48`): - Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will - contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where - num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that - needs to be satisfied is that of divisibility mentioned above. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. 
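`get_timesteps` above is the usual img2img truncation: only the last `int(num_inference_steps * strength)` entries of the scheduler's schedule are kept, so a lower `denoise_strength` starts sampling closer to the supplied latents. A tiny standalone illustration of the index arithmetic (no scheduler object involved, just the math; the helper name is illustrative):

```python
def truncated_timesteps(all_timesteps, num_inference_steps, strength, order=1):
    # keep only the tail of the schedule, as in get_timesteps above
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return all_timesteps[t_start * order:], num_inference_steps - t_start

# e.g. 50 steps at strength 0.6 -> only the last 30 timesteps are run
timesteps, steps = truncated_timesteps(list(range(1000, 0, -20)), 50, 0.6)
assert steps == 30 and len(timesteps) == 30
```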
- timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, defaults to `226`): - Maximum sequence length in encoded prompt. Must be consistent with - `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. - - Examples: - - Returns: - [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`: - [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. 
- """ - - # if num_frames > 49: - # raise ValueError( - # "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - # ) - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial - num_videos_per_prompt = 1 - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds, - negative_prompt_embeds, - ) - self._guidance_scale = guidance_scale - self._interrupt = False - - # 2. Default call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - self._num_timesteps = len(timesteps) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps + 2) - - # 5. Prepare latents. - latent_channels = self.vae.config.latent_channels - latents, timesteps, noise = self.prepare_latents( - batch_size * num_videos_per_prompt, - latent_channels, - num_frames, - height, - width, - self.vae.dtype, - device, - generator, - timesteps, - denoise_strength, - num_inference_steps, - latents, - context_size=context_frames, - context_overlap=context_overlap, - freenoise=freenoise, - ) - if comfyui_progressbar: - pbar.update(1) - - - control_video_latents_input = ( - torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video - ) - control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w") - - control_latents = control_latents * control_strength - - if comfyui_progressbar: - pbar.update(1) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - - - # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - if context_schedule is not None: - print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") - use_context_schedule = True - from .context import get_context_scheduler - context = get_context_scheduler(context_schedule) - - else: - use_context_schedule = False - print(" context schedule disabled") - # 7. 
Create rotary embeds if required - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - if tora is not None and do_classifier_free_guidance: - video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous() - - if tora is not None: - for module in self.transformer.fuser_list: - for param in module.parameters(): - param.data = param.data.to(device) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - # for DPM-solver++ - old_pred_original_sample = None - for i, t in enumerate(timesteps): - if self.interrupt: - continue - if use_context_schedule: - - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # Calculate the current step percentage - current_step_percentage = i / num_inference_steps - - # Determine if control_latents should be applied - apply_control = control_start_percent <= current_step_percentage <= control_end_percent - current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - context_queue = list(context( - i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap, - )) - counter = torch.zeros_like(latent_model_input) - noise_pred = torch.zeros_like(latent_model_input) - - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, context_frames, device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - - for c in context_queue: - partial_latent_model_input = latent_model_input[:, c, :, :, :] - partial_control_latents = current_control_latents[:, c, :, :, :] - - # predict noise model_output - noise_pred[:, c, :, :, :] += self.transformer( - hidden_states=partial_latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - return_dict=False, - control_latents=partial_control_latents, - )[0] - - counter[:, c, :, :, :] += 1 - noise_pred = noise_pred.float() - - noise_pred /= counter - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) - latents = latents.to(prompt_embeds.dtype) - - # call the callback, if provided - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 
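The context-schedule branch above denoises each sliding window separately, accumulates the per-window predictions into `noise_pred`, counts how often each frame was visited in `counter`, and divides the two so overlapping frames receive the averaged prediction before classifier-free guidance is applied. A minimal sketch of that accumulate-and-average pattern, with a stand-in `predict` callable in place of the transformer (names here are illustrative; the real context scheduler guarantees every frame is covered by at least one window):

```python
import torch

def windowed_prediction(latents, windows, predict, guidance_scale=6.0, cfg=True):
    """latents: [2*b if cfg else b, frames, c, h, w]; windows: list of frame-index lists."""
    noise_pred = torch.zeros_like(latents)
    counter = torch.zeros_like(latents)
    for c in windows:
        noise_pred[:, c] += predict(latents[:, c])   # per-window model output
        counter[:, c] += 1
    noise_pred = noise_pred / counter                # average over overlapping windows
    if cfg:
        uncond, text = noise_pred.chunk(2)           # [neg | pos] along the batch dim
        noise_pred = uncond + guidance_scale * (text - uncond)
    return noise_pred

# usage with a dummy predictor and two overlapping windows over 13 latent frames
dummy = lambda x: torch.zeros_like(x)
out = windowed_prediction(torch.randn(2, 13, 16, 60, 90),
                          [list(range(0, 8)), list(range(5, 13))], dummy)
```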
self.scheduler.order == 0): - progress_bar.update() - if comfyui_progressbar: - pbar.update(1) - else: - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # Calculate the current step percentage - current_step_percentage = i / num_inference_steps - - # Determine if control_latents should be applied - apply_control = control_start_percent <= current_step_percentage <= control_end_percent - current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - return_dict=False, - control_latents=current_control_latents, - video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None, - - )[0] - noise_pred = noise_pred.float() - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) - latents = latents.to(prompt_embeds.dtype) - - # call the callback, if provided - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if comfyui_progressbar: - pbar.update(1) - - # if output_type == "numpy": - # video = self.decode_latents(latents) - # elif not output_type == "latent": - # video = self.decode_latents(latents) - # video = self.video_processor.postprocess_video(video=video, output_type=output_type) - # else: - # video = latents - - # Offload all models - self.maybe_free_model_hooks() - - # if not return_dict: - # video = torch.from_numpy(video) - - return latents \ No newline at end of file diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py deleted file mode 100644 index a6f0e9e..0000000 --- a/cogvideox_fun/pipeline_cogvideox_inpaint.py +++ /dev/null @@ -1,1037 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from einops import rearrange - -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from diffusers.utils import BaseOutput, logging, replace_example_docstring -from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor -from diffusers.image_processor import VaeImageProcessor -from einops import rearrange - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```python - >>> import torch - >>> from diffusers import CogVideoX_Fun_Pipeline - >>> from diffusers.utils import export_to_video - - >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" - >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") - >>> prompt = ( - ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " - ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " - ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " - ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " - ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " - ... "atmosphere of this unique musical performance." - ... ) - >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] - >>> export_to_video(video, "output.mp4", fps=8) - ``` -""" - - -# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. 
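`get_resize_crop_region_for_grid` picks the centered sub-rectangle of the base RoPE grid (derived from 720x480) that matches the aspect ratio of the requested latent grid; the rotary embeddings are then built for that crop. A standalone restatement of the same computation with two concrete checks, purely to make the numbers tangible (the helper name `crop_region` is illustrative):

```python
def crop_region(src, tgt_width, tgt_height):
    h, w = src
    if h / w > tgt_height / tgt_width:
        rh, rw = tgt_height, int(round(tgt_height / h * w))
    else:
        rw, rh = tgt_width, int(round(tgt_width / w * h))
    top, left = int(round((tgt_height - rh) / 2.0)), int(round((tgt_width - rw) / 2.0))
    return (top, left), (top + rh, left + rw)

# a 480x720 video -> 30x45 latent grid fills the whole 45x30 base grid
assert crop_region((30, 45), 45, 30) == ((0, 0), (30, 45))
# a square 480x480 video -> 30x30 grid, centered horizontally inside the base grid
assert crop_region((30, 30), 45, 30) == ((0, 8), (30, 38))
```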
Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -def resize_mask(mask, latent, process_first_frame_only=True): - latent_size = latent.size() - batch_size, channels, num_frames, height, width = mask.shape - - if process_first_frame_only: - target_size = list(latent_size[2:]) - target_size[0] = 1 - first_frame_resized = F.interpolate( - mask[:, :, 0:1, :, :], - size=target_size, - mode='trilinear', - align_corners=False - ) - - target_size = list(latent_size[2:]) - target_size[0] = target_size[0] - 1 - if target_size[0] != 0: - remaining_frames_resized = F.interpolate( - mask[:, :, 1:, :, :], - size=target_size, - mode='trilinear', - align_corners=False - ) - resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) - else: - resized_mask = first_frame_resized - else: - target_size = list(latent_size[2:]) - resized_mask = F.interpolate( - mask, - size=target_size, - mode='trilinear', - align_corners=False - ) - return resized_mask - -def add_noise_to_reference_video(image, ratio=None): - if ratio is None: - sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) - sigma = torch.exp(sigma).to(image.dtype) - else: - sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio - - image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] - image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) - image = image + image_noise - return image - -@dataclass -class CogVideoX_Fun_PipelineOutput(BaseOutput): - r""" - Output class for CogVideo pipelines. - - Args: - video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing - denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape - `(batch_size, num_frames, channels, height, width)`. - """ - - videos: torch.Tensor - - -class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline): - r""" - Pipeline for text-to-video generation using CogVideoX. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - transformer ([`CogVideoXTransformer3DModel`]): - A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
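`add_noise_to_reference_video` above perturbs the conditioning video with Gaussian noise whose per-sample sigma is either drawn log-normally (around exp(-3) ≈ 0.05) or fixed via `ratio` (the `noise_aug_strength` argument used later), and it leaves padded regions, marked with the value -1, untouched. A small sketch of the fixed-ratio path under those assumptions, written inline rather than calling the deleted helper:

```python
import torch

image = torch.full((1, 3, 4, 8, 8), -1.0)        # [b, c, f, h, w]; all frames start as padding
image[:, :, :2] = 0.5                            # two real conditioning frames
sigma = 0.0563                                   # the pipeline's default noise_aug_strength
noise = torch.randn_like(image) * sigma
noise = torch.where(image == -1, torch.zeros_like(image), noise)  # keep padding noise-free
noisy = image + noise

assert torch.all(noisy[:, :, 2:] == -1)          # padded frames are untouched
```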
- """ - - _optional_components = [] - model_cpu_offload_seq = "vae->transformer->vae" - - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - ] - - def __init__( - self, - vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3DModel, - scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - ): - super().__init__() - - self.register_modules( - vae=vae, transformer=transformer, scheduler=scheduler - ) - self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True - ) - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - video_length, - dtype, - device, - generator, - latents=None, - video=None, - timestep=None, - is_strength_max=True, - return_noise=False, - return_video_latents=False, - context_size=None, - context_overlap=None, - freenoise=False, - ): - shape = ( - batch_size, - (video_length - 1) // self.vae_scale_factor_temporal + 1, - num_channels_latents, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if return_video_latents or (latents is None and not is_strength_max): - video = video.to(device=device, dtype=self.vae.dtype) - - bs = 1 - new_video = [] - for i in range(0, video.shape[0], bs): - video_bs = video[i : i + bs] - video_bs = self.vae.encode(video_bs)[0] - video_bs = video_bs.sample() - new_video.append(video_bs) - video = torch.cat(new_video, dim = 0) - video = video * self.vae.config.scaling_factor - - video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) - video_latents = video_latents.to(device=device, dtype=dtype) - video_latents = rearrange(video_latents, "b c f h w -> b f c h w") - - if latents is None: - noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=dtype) - if freenoise: - print("Applying FreeNoise") - # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) - video_length = video_length // 4 - delta = context_size - context_overlap - for start_idx in range(0, video_length-context_size, delta): - # start_idx corresponds to the beginning of a context window - # goal: place shuffled in the delta region right after the end of the context window - # if space after context window is not enough to place the noise, adjust and finish - place_idx = start_idx + context_size - # if place_idx is outside the valid indexes, we are already finished - if place_idx >= video_length: - break - end_idx = place_idx - 1 - #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta) - - # if there is not enough room to copy delta amount of indexes, copy limited amount and finish - if end_idx + delta >= video_length: - final_delta = video_length - place_idx - # generate list of indexes in final delta region - list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) - # shuffle list - list_idx = list_idx[torch.randperm(final_delta, generator=generator)] - # apply shuffled indexes - noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :] - break - # otherwise, do normal behavior - # generate list of indexes in delta region - list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) - # shuffle list - list_idx = list_idx[torch.randperm(delta, generator=generator)] - # apply shuffled indexes - #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx) - noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :] - - # if strength is 1. 
then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(video_latents.to(noise), noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents - latents = latents.to(device) - else: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma - - # scale the initial noise by the standard deviation required by the scheduler - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_video_latents: - outputs += (video_latents,) - - return outputs - - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - - if mask is not None: - mask = mask.to(device=device, dtype=self.vae.dtype) - bs = 1 - new_mask = [] - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.mode() - new_mask.append(mask_bs) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor - - if masked_image is not None: - if self.transformer.config.add_noise_in_inpaint_model: - masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) - bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.mode() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor - else: - masked_image_latents = None - - return mask, masked_image_latents - - def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - latents = 1 / self.vae.config.scaling_factor * latents - - frames = self.vae.decode(latents).sample - frames = (frames / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - frames = frames.cpu().float().numpy() - return frames - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
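`prepare_extra_step_kwargs` only forwards `eta` and `generator` when the scheduler's `step` signature actually accepts them, which is what lets the same pipeline drive both the DDIM- and DPM-style CogVideoX schedulers. The same introspection pattern in isolation, with a toy scheduler standing in for the real ones (all names here are illustrative):

```python
import inspect

class ToyScheduler:
    def step(self, model_output, timestep, sample, eta=0.0, return_dict=True):
        return sample  # placeholder: a real scheduler returns the denoised sample

def extra_step_kwargs(scheduler, generator=None, eta=0.0):
    params = set(inspect.signature(scheduler.step).parameters)
    kwargs = {}
    if "eta" in params:            # only schedulers with an eta argument get one
        kwargs["eta"] = eta
    if "generator" in params:      # likewise for generator
        kwargs["generator"] = generator
    return kwargs

assert extra_step_kwargs(ToyScheduler(), eta=0.1) == {"eta": 0.1}   # no generator kwarg passed
```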
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def fuse_qkv_projections(self) -> None: - r"""Enables fused QKV projections.""" - self.fusing_transformer = True - self.transformer.fuse_qkv_projections() - - def unfuse_qkv_projections(self) -> None: - r"""Disable QKV projection fusion if enabled.""" - if not self.fusing_transformer: - logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") - else: - self.transformer.unfuse_qkv_projections() - self.fusing_transformer = False - - def _prepare_rotary_positional_embeddings( - self, - height: int, - width: int, - num_frames: int, - device: torch.device, - start_frame: Optional[int] = None, - end_frame: Optional[int] = None, - context_frames: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=num_frames, - use_real=True, - ) - - if start_frame is not None or context_frames is not None: - freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1) - freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1) - if context_frames is not None: - freqs_cos = freqs_cos[context_frames] - freqs_sin = freqs_sin[context_frames] - else: - freqs_cos = freqs_cos[start_frame:end_frame] - freqs_sin = freqs_sin[start_frame:end_frame] - - freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1]) - freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1]) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - video: Union[torch.FloatTensor] = None, - mask_video: Union[torch.FloatTensor] = None, - masked_video_latents: Union[torch.FloatTensor] = None, - num_frames: int = 49, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - guidance_scale: float = 6, - use_dynamic_cfg: bool = False, - num_videos_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: 
List[str] = ["latents"], - max_sequence_length: int = 226, - strength: float = 1, - noise_aug_strength: float = 0.0563, - comfyui_progressbar: bool = False, - context_schedule: Optional[str] = None, - context_frames: Optional[int] = None, - context_stride: Optional[int] = None, - context_overlap: Optional[int] = None, - freenoise: Optional[bool] = True, - tora: Optional[dict] = None, - ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_frames (`int`, defaults to `48`): - Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will - contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where - num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that - needs to be satisfied is that of divisibility mentioned above. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, defaults to `226`): - Maximum sequence length in encoded prompt. Must be consistent with - `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. - - Examples: - - Returns: - [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`: - [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - - # if num_frames > 49: - # raise ValueError( - # "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - # ) - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial - num_videos_per_prompt = 1 - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds, - negative_prompt_embeds, - ) - self._guidance_scale = guidance_scale - self._interrupt = False - - # 2. Default call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - - # 4. 
set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps=num_inference_steps, strength=strength, device=device - ) - self._num_timesteps = len(timesteps) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps + 2) - # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1.0 - - # 5. Prepare latents. - if video is not None: - video_length = video.shape[2] - init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) - init_video = init_video.to(dtype=torch.float32) - init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) - else: - init_video = None - - num_channels_latents = self.vae.config.latent_channels - num_channels_transformer = self.transformer.config.in_channels - return_image_latents = num_channels_transformer == num_channels_latents - - self.vae.to(device) - - latents_outputs = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - video_length, - self.vae.dtype, - device, - generator, - latents, - video=init_video, - timestep=latent_timestep, - is_strength_max=is_strength_max, - return_noise=True, - return_video_latents=return_image_latents, - context_size=context_frames, - context_overlap=context_overlap, - freenoise=freenoise, - ) - if return_image_latents: - latents, noise, image_latents = latents_outputs - else: - latents, noise = latents_outputs - if comfyui_progressbar: - pbar.update(1) - - if mask_video is not None: - if (mask_video == 255).all(): - mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype) - masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents - ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) - else: - # Prepare mask latent variables - video_length = video.shape[2] - mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) - mask_condition = mask_condition.to(dtype=torch.float32) - mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) - - if num_channels_transformer != num_channels_latents: - mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) - if masked_video_latents is None: - masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 - else: - masked_video = masked_video_latents - - _, masked_video_latents = self.prepare_mask_latents( - None, - masked_video, - batch_size, - height, - width, - self.vae.dtype, - device, - generator, - do_classifier_free_guidance, - noise_aug_strength=noise_aug_strength, - ) - mask_latents = resize_mask(1 - mask_condition, masked_video_latents) - mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor - - mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 
1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents - ) - - mask = rearrange(mask, "b c f h w -> b f c h w") - mask_input = rearrange(mask_input, "b c f h w -> b f c h w") - masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w") - - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) - else: - mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - mask = rearrange(mask, "b c f h w -> b f c h w") - - inpaint_latents = None - else: - if num_channels_transformer != num_channels_latents: - mask = torch.zeros_like(latents).to(latents.device, latents.dtype) - masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents - ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) - else: - mask = torch.zeros_like(init_video[:, :1]) - mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - mask = rearrange(mask, "b c f h w -> b f c h w") - - inpaint_latents = None - - self.vae.to(torch.device("cpu")) - - if comfyui_progressbar: - pbar.update(1) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Create rotary embeds if required - if context_schedule is not None: - print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") - use_context_schedule = True - from .context import get_context_scheduler - context = get_context_scheduler(context_schedule) - else: - use_context_schedule = False - print("context schedule disabled") - # 7. Create rotary embeds if required - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - if tora is not None and do_classifier_free_guidance: - video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous() - - if tora is not None: - trajectory_length = tora["video_flow_features"].shape[1] - logger.info(f"Tora trajectory length: {trajectory_length}") - logger.info(f"Tora trajectory shape: {tora['video_flow_features'].shape}") - logger.info(f"latents shape: {latents.shape}") - if trajectory_length != latents.shape[1]: - raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent count {latents.shape[2]}") - for module in self.transformer.fuser_list: - for param in module.parameters(): - param.data = param.data.to(device) - - # 8. 
Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - from ..latent_preview import prepare_callback - callback = prepare_callback(self.transformer, num_inference_steps) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - # for DPM-solver++ - old_pred_original_sample = None - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - if use_context_schedule: - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - context_queue = list(context( - i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap, - )) - counter = torch.zeros_like(latent_model_input) - noise_pred = torch.zeros_like(latent_model_input) - - current_step_percentage = i / num_inference_steps - - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, context_frames, device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - - for c in context_queue: - partial_latent_model_input = latent_model_input[:, c, :, :, :] - partial_inpaint_latents = inpaint_latents[:, c, :, :, :] - partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :] - if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]): - if do_classifier_free_guidance: - partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous() - else: - partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :] - else: - partial_video_flow_features = None - - # predict noise model_output - noise_pred[:, c, :, :, :] += self.transformer( - hidden_states=partial_latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - return_dict=False, - inpaint_latents=partial_inpaint_latents, - video_flow_features=partial_video_flow_features - )[0] - - counter[:, c, :, :, :] += 1 - - noise_pred = noise_pred.float() - - noise_pred /= counter - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) - latents = latents.to(prompt_embeds.dtype) - - # call the callback, if provided - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if comfyui_progressbar: - 
pbar.update(1) - - else: - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - current_step_percentage = i / num_inference_steps - - # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - return_dict=False, - inpaint_latents=inpaint_latents, - video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None, - - )[0] - noise_pred = noise_pred.float() - - # perform guidance - if use_dynamic_cfg: - self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 - ) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) - latents = latents.to(prompt_embeds.dtype) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if comfyui_progressbar: - if callback is not None: - callback(i, latents.detach()[-1], None, num_inference_steps) - else: - pbar.update(1) - - # Offload all models - self.maybe_free_model_hooks() - - return latents \ No newline at end of file diff --git a/cogvideox_fun/transformer_3d.py b/cogvideox_fun/transformer_3d.py deleted file mode 100644 index 5b6fef9..0000000 --- a/cogvideox_fun/transformer_3d.py +++ /dev/null @@ -1,823 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
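For reference, the denoising loop of the inpaint pipeline deleted above combines an optional cosine-based dynamic guidance schedule with standard classifier-free guidance. The sketch below isolates those two formulas; the function names and the toy tensor are invented for illustration, while the arithmetic follows the removed code.

import math
import torch

def dynamic_guidance_scale(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Cosine rescaling of the guidance strength as a function of the current scheduler
    # timestep, as applied when use_dynamic_cfg is enabled in the loop above.
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t) / num_inference_steps) ** 5.0)) / 2
    )

def apply_cfg(noise_pred: torch.Tensor, scale: float) -> torch.Tensor:
    # The batch stacks the unconditional prediction first and the text-conditioned one second.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)

# Toy usage: a random tensor stands in for a latent-shaped noise prediction.
scale = dynamic_guidance_scale(guidance_scale=6.0, t=30.0, num_inference_steps=50)
guided = apply_cfg(torch.randn(2, 13, 16, 60, 90), scale)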
- -from typing import Any, Dict, Optional, Tuple, Union - -import os -import json -import torch -import glob -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import is_torch_version, logging -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.attention import Attention, FeedForward -from diffusers.models.attention_processor import AttentionProcessor#, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -from einops import rearrange -try: - from sageattention import sageattn - SAGEATTN_IS_AVAILABLE = True -except: - SAGEATTN_IS_AVAILABLE = False - -def fft(tensor): - tensor_fft = torch.fft.fft2(tensor) - tensor_fft_shifted = torch.fft.fftshift(tensor_fft) - B, C, H, W = tensor.size() - radius = min(H, W) // 5 - - Y, X = torch.meshgrid(torch.arange(H), torch.arange(W)) - center_x, center_y = W // 2, H // 2 - mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2 - low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device) - high_freq_mask = ~low_freq_mask - - low_freq_fft = tensor_fft_shifted * low_freq_mask - high_freq_fft = tensor_fft_shifted * high_freq_mask - - return low_freq_fft, high_freq_fft - -class CogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. 
- """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - attention_mode: Optional[str] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - from diffusers.models.embeddings import apply_rotary_emb - - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - - if attention_mode == "sageattn": - if SAGEATTN_IS_AVAILABLE: - hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False) - else: - raise ImportError("sageattn not found") - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - -class CogVideoXPatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, - ) -> None: - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - self.text_proj = nn.Linear(text_embed_dim, embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - r""" - Args: - text_embeds (`torch.Tensor`): - Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). - image_embeds (`torch.Tensor`): - Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). 
- """ - text_embeds = self.text_proj(text_embeds) - - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] - - embeds = torch.cat( - [text_embeds, image_embeds], dim=1 - ).contiguous() # [batch, seq_length + num_frames x height x width, channels] - return embeds - -@maybe_allow_in_graph -class CogVideoXBlock(nn.Module): - r""" - Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - time_embed_dim (`int`): - The number of channels in timestep embedding. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - qk_norm (`bool`, defaults to `True`): - Whether or not to use normalization after query and key projections in Attention. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - ff_inner_dim (`int`, *optional*, defaults to `None`): - Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in Feed-forward layer. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in Attention output projection layer. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - time_embed_dim: int, - dropout: float = 0.0, - activation_fn: str = "gelu-approximate", - attention_bias: bool = False, - qk_norm: bool = True, - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - final_dropout: bool = True, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - attention_mode: Optional[str] = None, - ): - super().__init__() - - # 1. Self Attention - self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.attn1 = Attention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=attention_bias, - out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), - ) - - # 2. 
Feed Forward - self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - self.cached_hidden_states = [] - self.cached_encoder_hidden_states = [] - self.attention_mode = attention_mode - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - video_flow_feature: Optional[torch.Tensor] = None, - fuser=None, - block_use_fastercache=False, - fastercache_counter=0, - fastercache_start_step=15, - fastercache_device="cuda:0", - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( - hidden_states, encoder_hidden_states, temb - ) - # Tora Motion-guidance Fuser - if video_flow_feature is not None: - H, W = video_flow_feature.shape[-2:] - T = norm_hidden_states.shape[1] // H // W - h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W) - h = fuser(h, video_flow_feature.to(h), T=T) - norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T) - del h, fuser - - #region fastercache - if block_use_fastercache: - B = norm_hidden_states.shape[0] - if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B: - attn_hidden_states = ( - self.cached_hidden_states[1][:B] + - (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B]) - * 0.3 - ).to(norm_hidden_states.device, non_blocking=True) - attn_encoder_hidden_states = ( - self.cached_encoder_hidden_states[1][:B] + - (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B]) - * 0.3 - ).to(norm_hidden_states.device, non_blocking=True) - else: - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mode=self.attention_mode, - ) - if fastercache_counter == fastercache_start_step: - self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)] - self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)] - elif fastercache_counter > fastercache_start_step: - self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device)) - self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device)) - else: - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mode=self.attention_mode, - ) - - hidden_states = hidden_states + gate_msa * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( - hidden_states, encoder_hidden_states, temb - ) - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] - 
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] - - return hidden_states, encoder_hidden_states - - -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): - """ - A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). - - Parameters: - num_attention_heads (`int`, defaults to `30`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `64`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `16`): - The number of channels in the output. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - time_embed_dim (`int`, defaults to `512`): - Output dimension of timestep embeddings. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - num_layers (`int`, defaults to `30`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_width (`int`, defaults to `90`): - The width of the input latents. - sample_height (`int`, defaults to `60`): - The height of the input latents. - sample_frames (`int`, defaults to `49`): - The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 - instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, - but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with - K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - temporal_compression_ratio (`int`, defaults to `4`): - The compression ratio across the temporal dimension. See documentation for `sample_frames`. - max_text_seq_length (`int`, defaults to `226`): - The maximum sequence length of the input text embeddings. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - timestep_activation_fn (`str`, defaults to `"silu"`): - Activation function to use when generating the timestep embeddings. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): - The epsilon value to use in normalization layers. - spatial_interpolation_scale (`float`, defaults to `1.875`): - Scaling factor to apply in 3D positional embeddings across spatial dimensions. - temporal_interpolation_scale (`float`, defaults to `1.0`): - Scaling factor to apply in 3D positional embeddings across temporal dimensions. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 64, - in_channels: int = 16, - out_channels: Optional[int] = 16, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - time_embed_dim: int = 512, - text_embed_dim: int = 4096, - num_layers: int = 30, - dropout: float = 0.0, - attention_bias: bool = True, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 49, - patch_size: int = 2, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - activation_fn: str = "gelu-approximate", - timestep_activation_fn: str = "silu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_rotary_positional_embeddings: bool = False, - add_noise_in_inpaint_model: bool = False, - attention_mode: Optional[str] = None, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - self.post_patch_height = post_patch_height - self.post_patch_width = post_patch_width - self.post_time_compression_frames = post_time_compression_frames - self.patch_size = patch_size - - # 1. Patch embedding - self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) - self.embedding_dropout = nn.Dropout(dropout) - - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) - - # 3. Time embeddings - self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - - # 4. Define spatio-temporal transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - CogVideoXBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - dropout=dropout, - activation_fn=activation_fn, - attention_bias=attention_bias, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - - # 5. 
Output blocks - self.norm_out = AdaLayerNorm( - embedding_dim=time_embed_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - chunk_dim=1, - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - - self.gradient_checkpointing = False - - self.fuser_list = None - - self.use_fastercache = False - self.fastercache_counter = 0 - self.fastercache_start_step = 15 - self.fastercache_lf_step = 40 - self.fastercache_hf_step = 30 - self.fastercache_device = "cuda" - self.fastercache_num_blocks_to_cache = len(self.transformer_blocks) - self.attention_mode = attention_mode - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. 
- - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - inpaint_latents: Optional[torch.Tensor] = None, - control_latents: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - video_flow_features: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. Patch embedding - if inpaint_latents is not None: - hidden_states = torch.concat([hidden_states, inpaint_latents], 2) - if control_latents is not None: - hidden_states = torch.concat([hidden_states, control_latents], 2) - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - - # 3. Position embedding - text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] - pos_embeds = self.pos_embedding - emb_size = hidden_states.size()[-1] - pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size) - pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3]) - pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False) - pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size) - pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1) - pos_embeds = pos_embeds[:, : text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) - - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - if self.use_fastercache: - self.fastercache_counter+=1 - if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0: - # 4. 
Transformer blocks - for i, block in enumerate(self.transformer_blocks): - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states[:1], - encoder_hidden_states=encoder_hidden_states[:1], - temb=emb[:1], - image_rotary_emb=image_rotary_emb, - video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None, - fuser = self.fuser_list[i] if self.fuser_list is not None else None, - block_use_fastercache = i <= self.fastercache_num_blocks_to_cache, - fastercache_start_step = self.fastercache_start_step, - fastercache_counter = self.fastercache_counter, - fastercache_device = self.fastercache_device - ) - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 5. Final block - hidden_states = self.norm_out(hidden_states, temb=emb[:1]) - hidden_states = self.proj_out(hidden_states) - - # 6. Unpatchify - p = self.config.patch_size - output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - (bb, tt, cc, hh, ww) = output.shape - cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) - lf_c, hf_c = fft(cond.float()) - #lf_step = 40 - #hf_step = 30 - if self.fastercache_counter <= self.fastercache_lf_step: - self.delta_lf = self.delta_lf * 1.1 - if self.fastercache_counter >= self.fastercache_hf_step: - self.delta_hf = self.delta_hf * 1.1 - - new_hf_uc = self.delta_hf + hf_c - new_lf_uc = self.delta_lf + lf_c - - combine_uc = new_lf_uc + new_hf_uc - combined_fft = torch.fft.ifftshift(combine_uc) - recovered_uncond = torch.fft.ifft2(combined_fft).real - recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww) - output = torch.cat([output, recovered_uncond]) - else: - # 4. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, - fuser = self.fuser_list[i] if self.fuser_list is not None else None, - block_use_fastercache = i <= self.fastercache_num_blocks_to_cache, - fastercache_counter = self.fastercache_counter, - fastercache_start_step = self.fastercache_start_step, - fastercache_device = self.fastercache_device - ) - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 5. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 6. 
Unpatchify - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if self.fastercache_counter >= self.fastercache_start_step + 1: - (bb, tt, cc, hh, ww) = output.shape - cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww) - uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww) - - lf_c, hf_c = fft(cond) - lf_uc, hf_uc = fft(uncond) - - self.delta_lf = lf_uc - lf_c - self.delta_hf = hf_uc - hf_c - - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}): - if subfolder is not None: - pretrained_model_path = os.path.join(pretrained_model_path, subfolder) - print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") - - config_file = os.path.join(pretrained_model_path, 'config.json') - if not os.path.isfile(config_file): - raise RuntimeError(f"{config_file} does not exist") - with open(config_file, "r") as f: - config = json.load(f) - - from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **transformer_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) - model_file_safetensors = model_file.replace(".bin", ".safetensors") - if os.path.exists(model_file): - state_dict = torch.load(model_file, map_location="cpu") - elif os.path.exists(model_file_safetensors): - from safetensors.torch import load_file, safe_open - state_dict = load_file(model_file_safetensors) - else: - from safetensors.torch import load_file, safe_open - model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) - state_dict = {} - for model_file_safetensors in model_files_safetensors: - _state_dict = load_file(model_file_safetensors) - for key in _state_dict: - state_dict[key] = _state_dict[key] - - if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size(): - new_shape = model.state_dict()['patch_embed.proj.weight'].size() - if len(new_shape) == 5: - state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone() - state_dict['patch_embed.proj.weight'][:, :, :-1] = 0 - else: - if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]: - model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight'] - model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0 - state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] - else: - model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :] - state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] - - tmp_state_dict = {} - for key in state_dict: - if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): - tmp_state_dict[key] = state_dict[key] - else: - print(key, "Size don't match, skip") - state_dict = tmp_state_dict - - m, u = model.load_state_dict(state_dict, strict=False) - print(f"### missing keys: 
{len(m)}; \n### unexpected keys: {len(u)};") - print(m) - - params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()] - print(f"### Mamba Parameters: {sum(params) / 1e6} M") - - params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] - print(f"### attn1 Parameters: {sum(params) / 1e6} M") - - return model \ No newline at end of file diff --git a/cogvideox_fun/utils.py b/cogvideox_fun/utils.py index f790161..9f670ec 100644 --- a/cogvideox_fun/utils.py +++ b/cogvideox_fun/utils.py @@ -1,26 +1,6 @@ -import os -import gc import numpy as np -import torch from PIL import Image -# Copyright (c) OpenMMLab. All rights reserved. - -def tensor2pil(image): - return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) - -def numpy2pil(image): - return Image.fromarray(np.clip(255. * image, 0, 255).astype(np.uint8)) - -def to_pil(image): - if isinstance(image, Image.Image): - return image - if isinstance(image, torch.Tensor): - return tensor2pil(image) - if isinstance(image, np.ndarray): - return numpy2pil(image) - raise ValueError(f"Cannot convert {type(image)} to PIL.Image") - ASPECT_RATIO_512 = { '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], @@ -54,126 +34,10 @@ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_5 closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) return ratios[closest_ratio], float(closest_ratio) - def get_width_and_height_from_image_and_base_resolution(image, base_resolution): target_pixels = int(base_resolution) * int(base_resolution) original_width, original_height = Image.open(image).size ratio = (target_pixels / (original_width * original_height)) ** 0.5 width_slider = round(original_width * ratio) height_slider = round(original_height * ratio) - return height_slider, width_slider - -def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): - if validation_image_start is not None and validation_image_end is not None: - if type(validation_image_start) is str and os.path.isfile(validation_image_start): - image_start = clip_image = Image.open(validation_image_start).convert("RGB") - image_start = image_start.resize([sample_size[1], sample_size[0]]) - clip_image = clip_image.resize([sample_size[1], sample_size[0]]) - else: - image_start = clip_image = validation_image_start - image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] - clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] - - if type(validation_image_end) is str and os.path.isfile(validation_image_end): - image_end = Image.open(validation_image_end).convert("RGB") - image_end = image_end.resize([sample_size[1], sample_size[0]]) - else: - image_end = validation_image_end - image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] - - if type(image_start) is list: - clip_image = clip_image[0] - start_video = torch.cat( - [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], - dim=2 - ) - input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) - input_video[:, :, :len(image_start)] = start_video - - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, 
len(image_start):] = 255 - else: - input_video = torch.tile( - torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), - [1, 1, video_length, 1, 1] - ) - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, 1:] = 255 - - if type(image_end) is list: - image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end] - end_video = torch.cat( - [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], - dim=2 - ) - input_video[:, :, -len(end_video):] = end_video - - input_video_mask[:, :, -len(image_end):] = 0 - else: - image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) - input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) - input_video_mask[:, :, -1:] = 0 - - input_video = input_video / 255 - - elif validation_image_start is not None: - if type(validation_image_start) is str and os.path.isfile(validation_image_start): - image_start = clip_image = Image.open(validation_image_start).convert("RGB") - image_start = image_start.resize([sample_size[1], sample_size[0]]) - clip_image = clip_image.resize([sample_size[1], sample_size[0]]) - else: - image_start = clip_image = validation_image_start - image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] - clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] - image_end = None - - if type(image_start) is list: - clip_image = clip_image[0] - start_video = torch.cat( - [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], - dim=2 - ) - input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) - input_video[:, :, :len(image_start)] = start_video - input_video = input_video / 255 - - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, len(image_start):] = 255 - else: - input_video = torch.tile( - torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), - [1, 1, video_length, 1, 1] - ) / 255 - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, 1:, ] = 255 - else: - image_start = None - image_end = None - input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) - input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 - clip_image = None - - del image_start - del image_end - gc.collect() - - return input_video, input_video_mask, clip_image - -def get_video_to_video_latent(input_video_path, video_length, sample_size, validation_video_mask=None): - input_video = input_video_path - - input_video = torch.from_numpy(np.array(input_video))[:video_length] - input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 - - if validation_video_mask is not None: - validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) - input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) - - input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) - input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) - input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) - else: - input_video_mask = 
torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, :] = 255 - - return input_video, input_video_mask, None \ No newline at end of file + return height_slider, width_slider \ No newline at end of file diff --git a/cogvideox_fun/context.py b/context.py similarity index 100% rename from cogvideox_fun/context.py rename to context.py diff --git a/convert_weight_sat2hf.py b/convert_weight_sat2hf.py deleted file mode 100644 index 545925b..0000000 --- a/convert_weight_sat2hf.py +++ /dev/null @@ -1,303 +0,0 @@ -""" - -The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format. -This script supports the conversion of the following models: -- CogVideoX-2B -- CogVideoX-5B, CogVideoX-5B-I2V -- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V - -Original Script: -https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py - -""" -import argparse -from typing import Any, Dict - -import torch -from transformers import T5EncoderModel, T5Tokenizer - -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDDIMScheduler, - CogVideoXImageToVideoPipeline, - CogVideoXPipeline, - #CogVideoXTransformer3DModel, -) -from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel - - -def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): - to_q_key = key.replace("query_key_value", "to_q") - to_k_key = key.replace("query_key_value", "to_k") - to_v_key = key.replace("query_key_value", "to_v") - to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) - state_dict[to_q_key] = to_q - state_dict[to_k_key] = to_k - state_dict[to_v_key] = to_v - state_dict.pop(key) - - -def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): - layer_id, weight_or_bias = key.split(".")[-2:] - - if "query" in key: - new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" - elif "key" in key: - new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" - - state_dict[new_key] = state_dict.pop(key) - - -def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): - layer_id, _, weight_or_bias = key.split(".")[-3:] - - weights_or_biases = state_dict[key].chunk(12, dim=0) - norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) - norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) - - norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" - state_dict[norm1_key] = norm1_weights_or_biases - - norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" - state_dict[norm2_key] = norm2_weights_or_biases - - state_dict.pop(key) - - -def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): - state_dict.pop(key) - - -def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): - key_split = key.split(".") - layer_index = int(key_split[2]) - replace_layer_index = 4 - 1 - layer_index - - key_split[1] = "up_blocks" - key_split[2] = str(replace_layer_index) - new_key = ".".join(key_split) - - state_dict[new_key] = state_dict.pop(key) - - -TRANSFORMER_KEYS_RENAME_DICT = { - "transformer.final_layernorm": "norm_final", - "transformer": "transformer_blocks", - "attention": "attn1", - "mlp": "ff.net", - "dense_h_to_4h": "0.proj", - "dense_4h_to_h": "2", - ".layers": "", - "dense": "to_out.0", - "input_layernorm": "norm1.norm", - "post_attn1_layernorm": "norm2.norm", - "time_embed.0": "time_embedding.linear_1", - "time_embed.2": "time_embedding.linear_2", - 
"mixins.patch_embed": "patch_embed", - "mixins.final_layer.norm_final": "norm_out.norm", - "mixins.final_layer.linear": "proj_out", - "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", - "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V -} - -TRANSFORMER_SPECIAL_KEYS_REMAP = { - "query_key_value": reassign_query_key_value_inplace, - "query_layernorm_list": reassign_query_key_layernorm_inplace, - "key_layernorm_list": reassign_query_key_layernorm_inplace, - "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, - "embed_tokens": remove_keys_inplace, - "freqs_sin": remove_keys_inplace, - "freqs_cos": remove_keys_inplace, - "position_embedding": remove_keys_inplace, -} - -VAE_KEYS_RENAME_DICT = { - "block.": "resnets.", - "down.": "down_blocks.", - "downsample": "downsamplers.0", - "upsample": "upsamplers.0", - "nin_shortcut": "conv_shortcut", - "encoder.mid.block_1": "encoder.mid_block.resnets.0", - "encoder.mid.block_2": "encoder.mid_block.resnets.1", - "decoder.mid.block_1": "decoder.mid_block.resnets.0", - "decoder.mid.block_2": "decoder.mid_block.resnets.1", -} - -VAE_SPECIAL_KEYS_REMAP = { - "loss": remove_keys_inplace, - "up.": replace_up_keys_inplace, -} - -TOKENIZER_MAX_LENGTH = 226 - - -def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: - state_dict = saved_dict - if "model" in saved_dict.keys(): - state_dict = state_dict["model"] - if "module" in saved_dict.keys(): - state_dict = state_dict["module"] - if "state_dict" in saved_dict.keys(): - state_dict = state_dict["state_dict"] - return state_dict - - -def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: - state_dict[new_key] = state_dict.pop(old_key) - - -def convert_transformer( - ckpt_path: str, - num_layers: int, - num_attention_heads: int, - use_rotary_positional_embeddings: bool, - i2v: bool, - dtype: torch.dtype, -): - PREFIX_KEY = "model.diffusion_model." 
- - original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - transformer = CogVideoXTransformer3DModel( - in_channels=32 if i2v else 16, - num_layers=num_layers, - num_attention_heads=num_attention_heads, - use_rotary_positional_embeddings=use_rotary_positional_embeddings, - use_learned_positional_embeddings=i2v, - ).to(dtype=dtype) - - for key in list(original_state_dict.keys()): - new_key = key[len(PREFIX_KEY):] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) - - for key in list(original_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, original_state_dict) - transformer.load_state_dict(original_state_dict, strict=True) - return transformer - - -def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): - original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) - - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) - - for key in list(original_state_dict.keys()): - for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, original_state_dict) - - vae.load_state_dict(original_state_dict, strict=True) - return vae - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" - ) - parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") - parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") - parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16") - parser.add_argument( - "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" - ) - parser.add_argument( - "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" - ) - # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 - parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") - # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 - parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") - # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True - parser.add_argument( - "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not" - ) - # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 - parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") - # For CogVideoX-2B, snr_shift_scale is 3.0. 
For 5B, it is 1.0 - parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") - parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - - transformer = None - vae = None - - if args.fp16 and args.bf16: - raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") - - dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 - - if args.transformer_ckpt_path is not None: - transformer = convert_transformer( - args.transformer_ckpt_path, - args.num_layers, - args.num_attention_heads, - args.use_rotary_positional_embeddings, - args.i2v, - dtype, - ) - if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) - - #text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl" - #tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) - #text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) - - # Apparently, the conversion does not work anymore without this :shrug: - #for param in text_encoder.parameters(): - # param.data = param.data.contiguous() - - scheduler = CogVideoXDDIMScheduler.from_config( - { - "snr_shift_scale": args.snr_shift_scale, - "beta_end": 0.012, - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "clip_sample": False, - "num_train_timesteps": 1000, - "prediction_type": "v_prediction", - "rescale_betas_zero_snr": True, - "set_alpha_to_one": True, - "timestep_spacing": "trailing", - } - ) - if args.i2v: - pipeline_cls = CogVideoXImageToVideoPipeline - else: - pipeline_cls = CogVideoXPipeline - - pipe = pipeline_cls( - tokenizer=None, - text_encoder=None, - vae=vae, - transformer=transformer, - scheduler=scheduler, - ) - - if args.fp16: - pipe = pipe.to(dtype=torch.float16) - if args.bf16: - pipe = pipe.to(dtype=torch.bfloat16) - - # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird - # for users to specify variant when the default is not fp32 and they want to run with the correct default (which - # is either fp16/bf16 here). - - # This is necessary This is necessary for users with insufficient memory, - # such as those using Colab and notebooks, as it can save some memory used for model loading. 
- pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 50b0f25..89c72aa 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -76,7 +76,6 @@ class CogVideoXAttnProcessor2_0: if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - #@torch.compiler.disable() def __call__( self, attn: Attention, diff --git a/model_loading.py b/model_loading.py index e627351..c77e3c5 100644 --- a/model_loading.py +++ b/model_loading.py @@ -43,11 +43,8 @@ from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel from .pipeline_cogvideox import CogVideoXPipeline from contextlib import nullcontext -from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun -from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun - -from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint -from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device from .utils import remove_specific_blocks, log from comfy.utils import load_torch_file @@ -121,8 +118,7 @@ class DownloadAndLoadCogVideoModel: "precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} ), - "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}), - "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), + "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), "lora": ("COGLORA", {"default": None}), @@ -132,13 +128,13 @@ class DownloadAndLoadCogVideoModel: } } - RETURN_TYPES = ("COGVIDEOPIPE",) - RETURN_NAMES = ("cogvideo_pipe", ) + RETURN_TYPES = ("COGVIDEOMODEL", "VAE",) + RETURN_NAMES = ("model", "vae", ) FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'" - def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", + def loadmodel(self, model, precision, quantization="disabled", compile="disabled", enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None, attention_mode="sdpa", load_device="main_device"): @@ -215,12 +211,7 @@ class DownloadAndLoadCogVideoModel: local_dir_use_symlinks=False, ) - #transformer - if "Fun" in model: - transformer = 
CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder) - else: - transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) - + transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) transformer = transformer.to(dtype).to(transformer_load_device) if "1.5" in model: @@ -235,17 +226,17 @@ class DownloadAndLoadCogVideoModel: scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) # VAE - if "Fun" in model: - vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) - if "Pose" in model: - pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler) - else: - pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) - else: - vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) - pipe = CogVideoXPipeline(vae, transformer, scheduler) - if "cogvideox-2b-img2vid" in model: - pipe.input_with_padding = False + vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) + + #pipeline + pipe = CogVideoXPipeline( + transformer, + scheduler, + dtype=dtype, + is_fun_inpaint=True if "fun" in model.lower() and "pose" not in model.lower() else False + ) + if "cogvideox-2b-img2vid" in model: + pipe.input_with_padding = False #LoRAs if lora is not None: @@ -281,8 +272,19 @@ class DownloadAndLoadCogVideoModel: lora_scale = lora_scale / lora_rank pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"]) + if "fused" in attention_mode: + from diffusers.models.attention import Attention + transformer.fuse_qkv_projections = True + for module in transformer.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + transformer.attention_mode = attention_mode + + if compile_args is not None: + pipe.transformer.to(memory_format=torch.channels_last) + #fp8 - if fp8_transformer == "enabled" or fp8_transformer == "fastmode": + if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode": params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"} if "1.5" in model: params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"}) @@ -290,13 +292,20 @@ class DownloadAndLoadCogVideoModel: if not any(keyword in name for keyword in params_to_keep): param.data = param.data.to(torch.float8_e4m3fn) - if fp8_transformer == "fastmode": + if quantization == "fp8_e4m3fn_fastmode": from .fp8_optimization import convert_fp8_linear if "1.5" in model: params_to_keep.update({"ff"}) #otherwise NaNs convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep) + + # compilation + if compile_args is not None: + torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] + for i, block in enumerate(pipe.transformer.transformer_blocks): + if "CogVideoXBlock" in str(block): + pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) - elif "torchao" in fp8_transformer: + if "torchao" in quantization: try: from torchao.quantization import ( quantize_, @@ -313,14 +322,14 @@ class DownloadAndLoadCogVideoModel: return isinstance(module, nn.Linear) return False - if "fp6" in fp8_transformer: #slower for some reason on 4090 + if "fp6" in quantization: #slower for some reason on 
4090 quant_func = fpx_weight_only(3, 2) - elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled + elif "fp8dq" in quantization: #very fast on 4090 when compiled quant_func = float8_dynamic_activation_float8_weight() - elif 'fp8dqrow' in fp8_transformer: + elif 'fp8dqrow' in quantization: from torchao.quantization.quant_api import PerRow quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow()) - elif 'int8dq' in fp8_transformer: + elif 'int8dq' in quantization: quant_func = int8_dynamic_activation_int8_weight() for i, block in enumerate(pipe.transformer.transformer_blocks): @@ -365,41 +374,19 @@ class DownloadAndLoadCogVideoModel: # (3): Dropout(p=0.0, inplace=False) # ) # ) - # ) + # ) + + # if compile == "onediff": + # from onediffx import compile_pipe + # os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1' - # compilation - if compile == "torch": - #pipe.transformer.to(memory_format=torch.channels_last) - if compile_args is not None: - torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] - for i, block in enumerate(pipe.transformer.transformer_blocks): - if "CogVideoXBlock" in str(block): - pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) - else: - for i, block in enumerate(pipe.transformer.transformer_blocks): - if "CogVideoXBlock" in str(block): - pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor") - - transformer.attention_mode = attention_mode - - if "fused" in attention_mode: - from diffusers.models.attention import Attention - transformer.fuse_qkv_projections = True - for module in transformer.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - elif compile == "onediff": - from onediffx import compile_pipe - os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1' - - pipe = compile_pipe( - pipe, - backend="nexfort", - options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}}, - ignores=["vae"], - fuse_qkv_projections= False, - ) + # pipe = compile_pipe( + # pipe, + # backend="nexfort", + # options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}}, + # ignores=["vae"], + # fuse_qkv_projections= False, + # ) pipeline = { "pipe": pipe, @@ -412,7 +399,7 @@ class DownloadAndLoadCogVideoModel: "model_name": model, } - return (pipeline,) + return (pipeline, vae) #region GGUF class DownloadAndLoadCogVideoGGUFModel: @classmethod @@ -444,8 +431,8 @@ class DownloadAndLoadCogVideoGGUFModel: } } - RETURN_TYPES = ("COGVIDEOPIPE",) - RETURN_NAMES = ("cogvideo_pipe", ) + RETURN_TYPES = ("COGVIDEOMODEL", "VAE",) + RETURN_NAMES = ("model", "vae",) FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" @@ -486,7 +473,6 @@ class DownloadAndLoadCogVideoGGUFModel: with open(transformer_path) as f: transformer_config = json.load(f) - from . 
import mz_gguf_loader import importlib @@ -498,7 +484,6 @@ class DownloadAndLoadCogVideoGGUFModel: transformer_config["in_channels"] = 32 else: transformer_config["in_channels"] = 33 - transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config) elif "I2V" in model or "Interpolation" in model: transformer_config["in_channels"] = 32 if "1_5" in model: @@ -508,10 +493,10 @@ class DownloadAndLoadCogVideoGGUFModel: transformer_config["patch_bias"] = False transformer_config["sample_height"] = 300 transformer_config["sample_width"] = 300 - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) else: transformer_config["in_channels"] = 16 - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + + transformer = CogVideoXTransformer3DModel.from_config(transformer_config) params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"} if "2b" in model: @@ -564,60 +549,25 @@ class DownloadAndLoadCogVideoGGUFModel: with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f: vae_config = json.load(f) + #VAE vae_sd = load_torch_file(vae_path) - if "fun" in model: - vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device) - vae.load_state_dict(vae_sd) - if "Pose" in model: - pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler) - else: - pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) - else: - vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device) - vae.load_state_dict(vae_sd) - pipe = CogVideoXPipeline(vae, transformer, scheduler) + vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device) + vae.load_state_dict(vae_sd) + del vae_sd + pipe = CogVideoXPipeline(transformer, scheduler, dtype=vae_dtype) if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() sd = load_torch_file(gguf_path) - - # #LoRAs - # if lora is not None: - # if "fun" in model.lower(): - # raise NotImplementedError("LoRA with GGUF is not supported for Fun models") - # from .lora_utils import merge_lora#, load_lora_into_transformer - # #for l in lora: - # # log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") - # # pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"]) - # else: - # adapter_list = [] - # adapter_weights = [] - # for l in lora: - # lora_sd = load_torch_file(l["path"]) - # for key, val in lora_sd.items(): - # if "lora_B" in key: - # lora_rank = val.shape[1] - # break - # log.info(f"Loading rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") - # adapter_name = l['path'].split("/")[-1].split(".")[0] - # adapter_weight = l['strength'] - # pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) - - # #transformer = load_lora_into_transformer(lora, transformer) - # adapter_list.append(adapter_name) - # adapter_weights.append(adapter_weight) - # for l in lora: - # pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) - # #pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"]) - pipe.transformer = mz_gguf_loader.quantize_load_state_dict(pipe.transformer, sd, device="cpu") + del sd + if load_device == "offload_device": pipe.transformer.to(offload_device) else: pipe.transformer.to(device) - pipeline = { "pipe": pipe, "dtype": vae_dtype, @@ -629,9 +579,253 @@ class DownloadAndLoadCogVideoGGUFModel: "manual_offloading": True, } + return (pipeline, vae) + 
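The hunk above is where the refactor actually separates the VAE from the pipeline object: the GGUF loader now builds AutoencoderKLCogVideoX from the bundled configs/vae_config.json, loads the VAE's own state dict into it, and returns it alongside the model instead of embedding it in the pipeline. A minimal standalone sketch of that separate-VAE loading pattern follows; it is not code from this patch, and it assumes AutoencoderKLCogVideoX is importable from diffusers and that the caller supplies the config and weight paths (the helper name, argument names, and default dtype are illustrative).

import json

import torch
from diffusers import AutoencoderKLCogVideoX  # assumed available in the installed diffusers version
from comfy.utils import load_torch_file


def load_separate_vae(config_path, vae_path, vae_dtype=torch.bfloat16, device="cpu"):
    # Build the VAE architecture from the bundled JSON config (e.g. configs/vae_config.json)...
    with open(config_path) as f:
        vae_config = json.load(f)
    vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(device)
    # ...then load the standalone VAE weights into the freshly constructed module.
    vae.load_state_dict(load_torch_file(vae_path))
    return vae

Downstream nodes would then receive this VAE object directly (the new "VAE" output socket) rather than reaching into a pipeline dict for it.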
+#region ModelLoader +class CogVideoXModelLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}), + + "base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}), + "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "optional quantization method"}), + "load_device": (["main_device", "offload_device"], {"default": "main_device"}), + "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), + }, + "optional": { + "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), + "lora": ("COGLORA", {"default": None}), + "compile_args":("COMPILEARGS", ), + "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}), + } + } + + RETURN_TYPES = ("COGVIDEOMODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "loadmodel" + CATEGORY = "CogVideoWrapper" + + def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload, + block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + manual_offloading = True + transformer_load_device = device if load_device == "main_device" else offload_device + mm.soft_empty_cache() + + base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision] + + model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) + sd = load_torch_file(model_path, device=transformer_load_device) + + model_type = "" + if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2): + model_type = "fun_5b" + elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2): + model_type = "5b" + elif sd["patch_embed.proj.weight"].shape == (3072, 128): + model_type = "5b_1_5" + elif sd["patch_embed.proj.weight"].shape == (3072, 256): + model_type = "5b_I2V_1_5" + elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2): + model_type = "fun_2b" + elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2): + model_type = "2b" + elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2): + if "pos_embedding" in sd: + model_type = "fun_5b_pose" + else: + model_type = "I2V_5b" + else: + raise Exception("Selected model is not recognized") + log.info(f"Detected CogVideoX model type: {model_type}") + + if "5b" in model_type: + scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') + transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_5b.json') + elif "2b" in model_type: + scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') + transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_2b.json') + + with open(transformer_config_path) as f: + transformer_config = json.load(f) + + with init_empty_weights(): + if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]: + transformer_config["in_channels"] = 32 + if "1_5" in model_type: + transformer_config["ofs_embed_dim"] = 512 + transformer_config["use_learned_positional_embeddings"] = False + transformer_config["patch_size_t"] = 2 + transformer_config["patch_bias"] = False 
+ transformer_config["sample_height"] = 300 + transformer_config["sample_width"] = 300 + elif "fun" in model_type: + transformer_config["in_channels"] = 33 + else: + if "1_5" in model_type: + transformer_config["use_learned_positional_embeddings"] = False + transformer_config["patch_size_t"] = 2 + transformer_config["patch_bias"] = False + #transformer_config["sample_height"] = 300 todo: check if this is needed + #transformer_config["sample_width"] = 300 + transformer_config["in_channels"] = 16 + + transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + + #load weights + #params_to_keep = {} + log.info("Using accelerate to load and assign model weights to device...") + + for name, param in transformer.named_parameters(): + #dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype + set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name]) + del sd + + + #scheduler + with open(scheduler_config_path) as f: + scheduler_config = json.load(f) + scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config, subfolder="scheduler") + + if block_edit is not None: + transformer = remove_specific_blocks(transformer, block_edit) + + if "fused" in attention_mode: + from diffusers.models.attention import Attention + transformer.fuse_qkv_projections = True + for module in transformer.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + transformer.attention_mode = attention_mode + + if "fun" in model_type: + if not "pose" in model_type: + raise NotImplementedError("Fun models besides pose are not supported with this loader yet") + pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) + else: + pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype) + else: + pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype) + + if enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload() + + #LoRAs + if lora is not None: + from .lora_utils import merge_lora#, load_lora_into_transformer + if "fun" in model.lower(): + for l in lora: + log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") + transformer = merge_lora(transformer, l["path"], l["strength"]) + else: + adapter_list = [] + adapter_weights = [] + for l in lora: + fuse = True if l["fuse_lora"] else False + lora_sd = load_torch_file(l["path"]) + for key, val in lora_sd.items(): + if "lora_B" in key: + lora_rank = val.shape[1] + break + log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") + adapter_name = l['path'].split("/")[-1].split(".")[0] + adapter_weight = l['strength'] + pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) + + #transformer = load_lora_into_transformer(lora, transformer) + adapter_list.append(adapter_name) + adapter_weights.append(adapter_weight) + for l in lora: + pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) + if fuse: + lora_scale = 1 + dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling + if any(item in lora[-1]["path"].lower() for item in dimension_loras): + lora_scale = lora_scale / lora_rank + pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"]) + + if compile_args is not None: + pipe.transformer.to(memory_format=torch.channels_last) + + #quantization + if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast": + params_to_keep = 
{"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"} + if "1.5" in model: + params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"}) + for name, param in pipe.transformer.named_parameters(): + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(torch.float8_e4m3fn) + + if quantization == "fp8_e4m3fn_fast": + from .fp8_optimization import convert_fp8_linear + if "1.5" in model: + params_to_keep.update({"ff"}) #otherwise NaNs + convert_fp8_linear(pipe.transformer, base_dtype, params_to_keep=params_to_keep) + + #compile + if compile_args is not None: + torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] + for i, block in enumerate(pipe.transformer.transformer_blocks): + if "CogVideoXBlock" in str(block): + pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) + + if "torchao" in quantization: + try: + from torchao.quantization import ( + quantize_, + fpx_weight_only, + float8_dynamic_activation_float8_weight, + int8_dynamic_activation_int8_weight + ) + except: + raise ImportError("torchao is not installed, please install torchao to use fp8dq") + + def filter_fn(module: nn.Module, fqn: str) -> bool: + target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models + if any(sub in fqn for sub in target_submodules): + return isinstance(module, nn.Linear) + return False + + if "fp6" in quantization: #slower for some reason on 4090 + quant_func = fpx_weight_only(3, 2) + elif "fp8dq" in quantization: #very fast on 4090 when compiled + quant_func = float8_dynamic_activation_float8_weight() + elif 'fp8dqrow' in quantization: + from torchao.quantization.quant_api import PerRow + quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow()) + elif 'int8dq' in quantization: + quant_func = int8_dynamic_activation_int8_weight() + + for i, block in enumerate(pipe.transformer.transformer_blocks): + if "CogVideoXBlock" in str(block): + quantize_(block, quant_func, filter_fn=filter_fn) + + manual_offloading = False # to disable manual .to(device) calls + log.info(f"Quantized transformer blocks to {quantization}") + + # if load_device == "offload_device": + # pipe.transformer.to(offload_device) + # else: + # pipe.transformer.to(device) + + pipeline = { + "pipe": pipe, + "dtype": base_dtype, + "base_path": model, + "onediff": False, + "cpu_offloading": enable_sequential_cpu_offload, + "scheduler_config": scheduler_config, + "model_name": model, + "manual_offloading": manual_offloading, + } + return (pipeline,) -#revion VAE +#region VAE class CogVideoXVAELoader: @classmethod @@ -829,6 +1023,7 @@ NODE_CLASS_MAPPINGS = { "DownloadAndLoadToraModel": DownloadAndLoadToraModel, "CogVideoLoraSelect": CogVideoLoraSelect, "CogVideoXVAELoader": CogVideoXVAELoader, + "CogVideoXModelLoader": CogVideoXModelLoader, } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -837,4 +1032,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadToraModel": "(Down)load Tora Model", "CogVideoLoraSelect": "CogVideo LoraSelect", "CogVideoXVAELoader": "CogVideoX VAE Loader", + "CogVideoXModelLoader": "CogVideoX Model Loader", } \ No newline at end of file diff --git a/nodes.py b/nodes.py index b18a978..bd89609 100644 --- 
a/nodes.py +++ b/nodes.py @@ -1,7 +1,6 @@ import os import torch -import folder_paths -import comfy.model_management as mm +import json from einops import rearrange from contextlib import nullcontext @@ -38,11 +37,10 @@ scheduler_mapping = { } available_schedulers = list(scheduler_mapping.keys()) -from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil +from diffusers.video_processor import VideoProcessor -from PIL import Image -import numpy as np -import json +import folder_paths +import comfy.model_management as mm script_directory = os.path.dirname(os.path.abspath(__file__)) @@ -129,94 +127,6 @@ class CogVideoXTorchCompileSettings: return (compile_args, ) #region TextEncode -class CogVideoEncodePrompt: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "pipeline": ("COGVIDEOPIPE",), - "prompt": ("STRING", {"default": "", "multiline": True} ), - "negative_prompt": ("STRING", {"default": "", "multiline": True} ), - } - } - - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "process" - CATEGORY = "CogVideoWrapper" - - def process(self, pipeline, prompt, negative_prompt): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - pipe = pipeline["pipe"] - dtype = pipeline["dtype"] - - pipe.text_encoder.to(device) - pipe.transformer.to(offload_device) - - positive, negative = pipe.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=True, - num_videos_per_prompt=1, - max_sequence_length=226, - device=device, - dtype=dtype, - ) - pipe.text_encoder.to(offload_device) - - return (positive, negative) - -# Inject clip_l and t5xxl w/ individual strength adjustments for ComfyUI's DualCLIPLoader node for CogVideoX. Use CLIPSave node from any SDXL model then load in a custom clip_l model. -# For some reason seems to give a lot more movement and consistency on new CogVideoXFun img2vid? set 'type' to flux / DualClipLoader. -class CogVideoDualTextEncode_311: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "clip": ("CLIP",), - "clip_l": ("STRING", {"default": "", "multiline": True}), - "t5xxl": ("STRING", {"default": "", "multiline": True}), - "clip_l_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), # excessive max for testing, have found intesting results up to 20 max? - "t5xxl_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), # setting this to 0.0001 or level as high as 18 seems to work. 
- } - } - - RETURN_TYPES = ("CONDITIONING",) - RETURN_NAMES = ("conditioning",) - FUNCTION = "process" - CATEGORY = "CogVideoWrapper" - - def process(self, clip, clip_l, t5xxl, clip_l_strength, t5xxl_strength): - load_device = mm.text_encoder_device() - offload_device = mm.text_encoder_offload_device() - - # setup tokenizer for clip_l and t5xxl - clip.tokenizer.t5xxl.pad_to_max_length = True - clip.tokenizer.t5xxl.max_length = 226 - clip.cond_stage_model.to(load_device) - - # tokenize clip_l and t5xxl - tokens_l = clip.tokenize(clip_l, return_word_ids=True) - tokens_t5 = clip.tokenize(t5xxl, return_word_ids=True) - - # encode the tokens separately - embeds_l = clip.encode_from_tokens(tokens_l, return_pooled=False, return_dict=False) - embeds_t5 = clip.encode_from_tokens(tokens_t5, return_pooled=False, return_dict=False) - - # apply strength adjustments to each embedding - if embeds_l.dim() == 3: - embeds_l *= clip_l_strength - if embeds_t5.dim() == 3: - embeds_t5 *= t5xxl_strength - - # combine the embeddings by summing them - combined_embeds = embeds_l + embeds_t5 - - # offload the model to save memory - clip.cond_stage_model.to(offload_device) - - return (combined_embeds,) - class CogVideoTextEncode: @classmethod def INPUT_TYPES(s): @@ -285,20 +195,31 @@ class CogVideoTextEncodeCombine: return (embeds, ) -#region ImageEncode +#region ImageEncode + +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + class CogVideoImageEncode: @classmethod def INPUT_TYPES(s): return {"required": { - "pipeline": ("COGVIDEOPIPE",), - "image": ("IMAGE", ), + "vae": ("VAE",), + "start_image": ("IMAGE", ), }, "optional": { - "chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}), + "end_image": ("IMAGE", ), "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), - "mask": ("MASK", ), "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}), - "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}), }, } @@ -307,49 +228,111 @@ class CogVideoImageEncode: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None): + def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0): device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) - - B, H, W, C = image.shape - - vae = pipeline["pipe"].vae if vae_override is None else vae_override - vae.enable_slicing() - model_name = pipeline.get("model_name", "") - - if ("1.5" in model_name or "1_5" in model_name) and image.shape[0] == 1: - vae_scaling_factor = 1 #/ vae.config.scaling_factor - else: - vae_scaling_factor = vae.config.scaling_factor + + try: + vae.enable_slicing() + except: + pass + + vae_scaling_factor = vae.config.scaling_factor if enable_tiling: from 
.mz_enable_vae_encode_tiling import enable_vae_encode_tiling enable_vae_encode_tiling(vae) - if not pipeline["cpu_offloading"]: - vae.to(device) + vae.to(device) - check_diffusers_version() try: vae._clear_fake_context_parallel_cache() except: pass - input_image = image.clone() - if mask is not None: - pipeline["pipe"].original_mask = mask - # print(mask.shape) - # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W] - # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C) - # print(mask.shape) - # input_image = input_image * (1 -mask) - else: - pipeline["pipe"].original_mask = None - #input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W - #input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype) - #input_image = input_image.unsqueeze(2) + if noise_aug_strength > 0: + start_image = add_noise_to_reference_video(start_image, ratio=noise_aug_strength) + if end_image is not None: + end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength) + + latents_list = [] + start_image = (start_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + start_latents = vae.encode(start_image).latent_dist.sample(generator) + start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W + + if end_image is not None: + end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) + end_latents = vae.encode(end_image).latent_dist.sample(generator) + end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W + latents_list = [start_latents, end_latents] + final_latents = torch.cat(latents_list, dim=1) + else: + final_latents = start_latents + + final_latents = final_latents * vae_scaling_factor + + log.info(f"Encoded latents shape: {final_latents.shape}") + vae.to(offload_device) + + return ({"samples": final_latents}, ) + +class CogVideoImageEncodeFunInP: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("VAE",), + "start_image": ("IMAGE", ), + "num_frames": ("INT", {"default": 49, "min": 2, "max": 1024, "step": 1}), + }, + "optional": { + "end_image": ("IMAGE", ), + "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), + "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}), + }, + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("image_cond_latents",) + FUNCTION = "encode" + CATEGORY = "CogVideoWrapper" + + def encode(self, vae, start_image, num_frames, end_image=None, enable_tiling=False, noise_aug_strength=0.0): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + generator = torch.Generator(device=device).manual_seed(0) + + try: + vae.enable_slicing() + except: + pass + + vae_scaling_factor = vae.config.scaling_factor + + if enable_tiling: + from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling + enable_vae_encode_tiling(vae) + + vae.to(device) + + try: + vae._clear_fake_context_parallel_cache() + except: + pass + + if end_image is not None: + # Create a tensor of zeros for padding + padding = torch.zeros((num_frames - 2, start_image.shape[1], start_image.shape[2], 3), device=end_image.device, dtype=end_image.dtype) * -1 + # Concatenate start_image, padding, and end_image + input_image = torch.cat([start_image, padding, end_image], dim=0) + else: + # Create a tensor of zeros for padding + padding = torch.zeros((num_frames - 1, start_image.shape[1], start_image.shape[2], 3), 
device=start_image.device, dtype=start_image.dtype) * -1 + # Concatenate start_image and padding + input_image = torch.cat([start_image, padding], dim=0) + input_image = input_image * 2.0 - 1.0 input_image = input_image.to(vae.dtype).to(device) input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W @@ -357,120 +340,34 @@ class CogVideoImageEncode: B, C, T, H, W = input_image.shape if noise_aug_strength > 0: input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength) + + bs = 1 + new_mask_pixel_values = [] + print("input_image shape: ",input_image.shape) + for i in range(0, input_image.shape[0], bs): + mask_pixel_values_bs = input_image[i : i + bs] + mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0] + print("mask_pixel_values_bs: ",mask_pixel_values_bs.parameters.shape) + mask_pixel_values_bs = mask_pixel_values_bs.mode() + print("mask_pixel_values_bs: ",mask_pixel_values_bs.shape, mask_pixel_values_bs.min(), mask_pixel_values_bs.max()) + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W - latents_list = [] - # Loop through the temporal dimension in chunks of 16 - for i in range(0, T, chunk_size): - # Get the chunk of 16 frames (or remaining frames if less than 16 are left) - end_index = min(i + chunk_size, T) - image_chunk = input_image[:, :, i:end_index, :, :] # Shape: [B, C, chunk_size, H, W] + mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :]) + if end_image is not None: + mask[:, -1, :, :, :] = vae_scaling_factor + mask[:, 0, :, :, :] = vae_scaling_factor - # Encode the chunk of images - latents = vae.encode(image_chunk) - - sample_mode = "sample" - if hasattr(latents, "latent_dist") and sample_mode == "sample": - latents = latents.latent_dist.sample(generator) - elif hasattr(latents, "latent_dist") and sample_mode == "argmax": - latents = latents.latent_dist.mode() - elif hasattr(latents, "latents"): - latents = latents.latents - - latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W - latents_list.append(latents) - - # Concatenate all the chunks along the temporal dimension - final_latents = torch.cat(latents_list, dim=1) - final_latents = final_latents * vae_scaling_factor + final_latents = masked_image_latents * vae_scaling_factor log.info(f"Encoded latents shape: {final_latents.shape}") - if not pipeline["cpu_offloading"]: - vae.to(offload_device) + vae.to(offload_device) - return ({"samples": final_latents}, ) - -class CogVideoImageInterpolationEncode: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "pipeline": ("COGVIDEOPIPE",), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }, - "optional": { - "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), - "mask": ("MASK", ), - "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}), - - }, - } - - RETURN_TYPES = ("LATENT",) - RETURN_NAMES = ("samples",) - FUNCTION = "encode" - CATEGORY = "CogVideoWrapper" - - def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - generator = torch.Generator(device=device).manual_seed(0) - - B, H, W, C = start_image.shape - - vae = pipeline["pipe"].vae if vae_override is None else vae_override - vae.enable_slicing() - 
model_name = pipeline.get("model_name", "") - - if ("1.5" in model_name or "1_5" in model_name): - vae_scaling_factor = 1 / vae.config.scaling_factor - else: - vae_scaling_factor = vae.config.scaling_factor - vae.enable_slicing() - - if enable_tiling: - from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling - enable_vae_encode_tiling(vae) - - if not pipeline["cpu_offloading"]: - vae.to(device) - - check_diffusers_version() - try: - vae._clear_fake_context_parallel_cache() - except: - pass - - if mask is not None: - pipeline["pipe"].original_mask = mask - # print(mask.shape) - # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W] - # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C) - # print(mask.shape) - # input_image = input_image * (1 -mask) - else: - pipeline["pipe"].original_mask = None - - start_image = (start_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W - end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) - B, T, C, H, W = start_image.shape - - latents_list = [] - - # Encode the chunk of images - start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor - end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor - - start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W - end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W - latents_list = [start_latents, end_latents] - - # Concatenate all the chunks along the temporal dimension - final_latents = torch.cat(latents_list, dim=1) - log.info(f"Encoded latents shape: {final_latents.shape}") - if not pipeline["cpu_offloading"]: - vae.to(offload_device) - - return ({"samples": final_latents}, ) + return ({ + "samples": final_latents, + "mask": mask + },) #region Tora from .tora.traj_utils import process_traj, scale_traj_list_to_256 @@ -480,8 +377,8 @@ class ToraEncodeTrajectory: @classmethod def INPUT_TYPES(s): return {"required": { - "pipeline": ("COGVIDEOPIPE",), "tora_model": ("TORAMODEL",), + "vae": ("VAE",), "coordinates": ("STRING", {"forceInput": True}), "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}), "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), @@ -491,7 +388,7 @@ class ToraEncodeTrajectory: "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, "optional": { - "enable_tiling": ("BOOLEAN", {"default": False}), + "enable_tiling": ("BOOLEAN", {"default": True}), } } @@ -500,14 +397,16 @@ class ToraEncodeTrajectory: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model, enable_tiling=False): + def encode(self, vae, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model, enable_tiling=False): check_diffusers_version() device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) - vae = pipeline["pipe"].vae - vae.enable_slicing() + try: + vae.enable_slicing() + except: + pass try: vae._clear_fake_context_parallel_cache() except: @@ -533,33 +432,26 @@ class ToraEncodeTrajectory: video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device) video_flow = rearrange(video_flow, "T H W C -> T C H W") video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W] - - - video_flow = ( - rearrange(video_flow / 255.0 * 2 - 1, 
"B T C H W -> B C T H W").contiguous().to(vae.dtype) - ) + video_flow = (rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)) video_flow_image = rearrange(video_flow, "B C T H W -> (B T) H W C") - print(video_flow_image.shape) + #print(video_flow_image.shape) mm.soft_empty_cache() # VAE encode - if not pipeline["cpu_offloading"]: - vae.to(device) - + vae.to(device) video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor log.info(f"video_flow shape after encoding: {video_flow.shape}") #torch.Size([1, 16, 4, 80, 80]) + vae.to(offload_device) - if not pipeline["cpu_offloading"]: - vae.to(offload_device) + tora_model["traj_extractor"].to(device) #print("video_flow shape before traj_extractor: ", video_flow.shape) #torch.Size([1, 16, 4, 80, 80]) video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) + tora_model["traj_extractor"].to(offload_device) video_flow_features = torch.stack(video_flow_features) #print("video_flow_features after traj_extractor: ", video_flow_features.shape) #torch.Size([42, 4, 128, 40, 40]) video_flow_features = video_flow_features * strength - - tora = { "video_flow_features" : video_flow_features, "start_percent" : start_percent, @@ -574,7 +466,7 @@ class ToraEncodeOpticalFlow: @classmethod def INPUT_TYPES(s): return {"required": { - "pipeline": ("COGVIDEOPIPE",), + "vae": ("VAE",), "tora_model": ("TORAMODEL",), "optical_flow": ("IMAGE", ), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), @@ -589,15 +481,18 @@ class ToraEncodeOpticalFlow: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, optical_flow, strength, tora_model, start_percent, end_percent): + def encode(self, vae, optical_flow, strength, tora_model, start_percent, end_percent): check_diffusers_version() B, H, W, C = optical_flow.shape device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) - vae = pipeline["pipe"].vae - vae.enable_slicing() + try: + vae.enable_slicing() + except: + pass + try: vae._clear_fake_context_parallel_cache() except: @@ -609,15 +504,14 @@ class ToraEncodeOpticalFlow: mm.soft_empty_cache() # VAE encode - if not pipeline["cpu_offloading"]: - vae.to(device) + + vae.to(device) video_flow = video_flow.to(vae.dtype).to(vae.device) video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor vae.to(offload_device) video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) video_flow_features = torch.stack(video_flow_features) - video_flow_features = video_flow_features * strength log.info(f"video_flow shape: {video_flow.shape}") @@ -632,91 +526,7 @@ class ToraEncodeOpticalFlow: return (tora, ) -def add_noise_to_reference_video(image, ratio=None): - if ratio is None: - sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) - sigma = torch.exp(sigma).to(image.dtype) - else: - sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio - - image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] - image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) - image = image + image_noise - return image -class CogVideoControlImageEncode: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "pipeline": ("COGVIDEOPIPE",), - "control_video": ("IMAGE", ), - "base_resolution": ("INT", {"min": 64, "max": 1280, "step": 64, "default": 
512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}), - "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), - "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }, - } - - RETURN_TYPES = ("COGCONTROL_LATENTS", "INT", "INT",) - RETURN_NAMES = ("control_latents", "width", "height") - FUNCTION = "encode" - CATEGORY = "CogVideoWrapper" - - def encode(self, pipeline, control_video, base_resolution, enable_tiling, noise_aug_strength=0.0563): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - - B, H, W, C = control_video.shape - - vae = pipeline["pipe"].vae - vae.enable_slicing() - - if enable_tiling: - from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling - enable_vae_encode_tiling(vae) - - if not pipeline["cpu_offloading"]: - vae.to(device) - - # Count most suitable height and width - aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} - - control_video = np.array(control_video.cpu().numpy() * 255, np.uint8) - original_width, original_height = Image.fromarray(control_video[0]).size - - closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) - height, width = [int(x / 16) * 16 for x in closest_size] - log.info(f"Closest bucket size: {width}x{height}") - - video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1 - input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width)) - - control_video = pipeline["pipe"].image_processor.preprocess(rearrange(input_video, "b c f h w -> (b f) c h w"), height=height, width=width) - control_video = control_video.to(dtype=torch.float32) - control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) - - masked_image = control_video.to(device=device, dtype=vae.dtype) - if noise_aug_strength > 0: - masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) - bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.mode() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * vae.config.scaling_factor - - vae.to(offload_device) - - control_latents = { - "latents": masked_image_latents, - "num_frames" : B, - "height" : height, - "width" : width, - } - - return (control_latents, width, height) #region FasterCache class CogVideoXFasterCache: @@ -757,12 +567,10 @@ class CogVideoSampler: def INPUT_TYPES(s): return { "required": { - "pipeline": ("COGVIDEOPIPE",), + "model": ("COGVIDEOMODEL",), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), - "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 16}), - "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 16}), - "num_frames": ("INT", {"default": 49, "min": 17, "max": 1024, "step": 4}), + "num_frames": ("INT", {"default": 49, "min": 1, "max": 1024, "step": 1}), "steps": ("INT", {"default": 50, "min": 1}), "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 
30.0, "step": 0.01}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), @@ -773,45 +581,54 @@ class CogVideoSampler: }, "optional": { "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), - "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "image_cond_latents": ("LATENT",{"tooltip": "Latent to use for image2video conditioning"} ), + "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "context_options": ("COGCONTEXT", ), "controlnet": ("COGVIDECONTROLNET",), "tora_trajectory": ("TORAFEATURES", ), "fastercache": ("FASTERCACHEARGS", ), - #"sigmas": ("SIGMAS", ), } } - RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",) - RETURN_NAMES = ("cogvideo_pipe", "samples",) + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) FUNCTION = "process" CATEGORY = "CogVideoWrapper" - def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None, + def process(self, pipeline, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None): mm.soft_empty_cache() - base_path = pipeline["base_path"] model_name = pipeline.get("model_name", "") - supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False + supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() or "fun" in model_name.lower() else False - assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'" - assert ( - "I2V" not in model_name or - "1.5" in model_name or - "1_5" in model_name or - num_frames == 49 or - context_options is not None - ), "1.0 I2V model can only do 49 frames" + if "fun" in model_name.lower() and image_cond_latents is not None: + assert image_cond_latents["mask"] is not None, "For fun inpaint models use CogVideoImageEncodeFunInP" + fun_mask = image_cond_latents["mask"] + else: + fun_mask = None + if image_cond_latents is not None: assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models" - # if "I2V" in model_name: - # assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent" - # elif "interpolation" in model_name.lower(): - # assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents" + image_conds = image_cond_latents["samples"] + if "1.5" in model_name or "1_5" in model_name: + image_conds = image_conds / 0.7 # needed for 1.5 models else: - assert not supports_image_conds, "Image condition latents required for I2V models" + if not "fun" in model_name.lower(): + assert not supports_image_conds, "Image condition latents required for I2V models" + image_conds = None + + if samples is not None: + if len(samples["samples"].shape) == 5: + B, T, C, H, W = samples["samples"].shape + latents = samples["samples"] + if len(samples["samples"].shape) == 4: + B, C, H, W = samples["samples"].shape + latents = None + if image_cond_latents is not None: + B, T, C, H, W = image_cond_latents["samples"].shape + height = H * 8 + width = W * 8 device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -861,9 +678,6 @@ class CogVideoSampler: cfg = [cfg for _ in range(steps)] else: assert len(cfg) == steps, "Length of cfg list must match number of 
steps" - - # if sigmas is not None: - # sigma_list = sigmas.tolist() try: torch.cuda.reset_peak_memory_stats(device) except: @@ -878,9 +692,9 @@ class CogVideoSampler: width = width, num_frames = num_frames, guidance_scale=cfg, - #sigmas=sigma_list if sigmas is not None else None, - latents=samples["samples"] if samples is not None else None, - image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None, + latents=latents if samples is not None else None, + fun_mask = fun_mask, + image_cond_latents=image_conds, denoise_strength=denoise_strength, prompt_embeds=positive.to(dtype).to(device), negative_prompt_embeds=negative.to(dtype).to(device), @@ -910,7 +724,11 @@ class CogVideoSampler: except: pass - return (pipeline, {"samples": latents}) + additional_frames = getattr(pipe, "additional_frames", 0) + return ({ + "samples": latents, + "additional_frames": additional_frames, + },) class CogVideoControlNet: @classmethod @@ -930,13 +748,7 @@ class CogVideoControlNet: CATEGORY = "CogVideoWrapper" def encode(self, controlnet, images, control_strength, control_start_percent, control_end_percent): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - - B, H, W, C = images.shape - control_frames = images.permute(0, 3, 1, 2).unsqueeze(0) * 2 - 1 - controlnet = { "control_model": controlnet, "control_frames": control_frames, @@ -944,7 +756,6 @@ class CogVideoControlNet: "control_start": control_start_percent, "control_end": control_end_percent, } - return (controlnet,) #region VideoDecode @@ -952,8 +763,8 @@ class CogVideoDecode: @classmethod def INPUT_TYPES(s): return {"required": { - "pipeline": ("COGVIDEOPIPE",), "samples": ("LATENT", ), + "vae": ("VAE", {"default": None}), "enable_vae_tiling": ("BOOLEAN", {"default": True, "tooltip": "Drastically reduces memory use but may introduce seams"}), }, "optional": { @@ -962,7 +773,6 @@ class CogVideoDecode: "tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}), - "vae_override": ("VAE", {"default": None}), } } @@ -971,19 +781,20 @@ class CogVideoDecode: FUNCTION = "decode" CATEGORY = "CogVideoWrapper" - def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, - auto_tile_size=True, vae_override=None): + def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, + auto_tile_size=True, pipeline=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() latents = samples["samples"] - vae = pipeline["pipe"].vae if vae_override is None else vae_override + + additional_frames = samples.get("additional_frames", 0) - additional_frames = getattr(pipeline["pipe"], "additional_frames", 0) + try: + vae.enable_slicing() + except: + pass - vae.enable_slicing() - - if not pipeline["cpu_offloading"]: - vae.to(device) + vae.to(device) if enable_vae_tiling: if auto_tile_size: vae.enable_tiling() @@ -999,11 +810,11 @@ class CogVideoDecode: latents = latents.to(vae.dtype).to(device) latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = 1 / vae.config.scaling_factor * latents + try: 
vae._clear_fake_context_parallel_cache() except: pass - try: frames = vae.decode(latents[:, :, additional_frames:]).sample except: @@ -1013,11 +824,13 @@ class CogVideoDecode: frames = vae.decode(latents[:, :, additional_frames:]).sample vae.disable_tiling() - if not pipeline["cpu_offloading"]: - vae.to(offload_device) + vae.to(offload_device) mm.soft_empty_cache() - video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt") + video_processor = VideoProcessor(vae_scale_factor=8) + video_processor.config.do_resize = False + + video = video_processor.postprocess_video(video=frames, output_type="pt") video = video[0].permute(0, 2, 3, 1).cpu().float() return (video,) @@ -1041,6 +854,7 @@ class CogVideoXFunResizeToClosestBucket: def resize(self, images, base_resolution, upscale_method, crop): from comfy.utils import common_upscale + from .cogvideox_fun.utils import ASPECT_RATIO_512, get_closest_ratio B, H, W, C = images.shape # Count most suitable height and width @@ -1056,256 +870,6 @@ class CogVideoXFunResizeToClosestBucket: return (resized_images, width, height) -#region FunSamplers -class CogVideoXFunSampler: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "pipeline": ("COGVIDEOPIPE",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "video_length": ("INT", {"default": 49, "min": 5, "max": 2048, "step": 4}), - "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}), - "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), - "seed": ("INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}), - "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}), - "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}), - "scheduler": (available_schedulers, {"default": 'DDIM'}) - }, - "optional":{ - "start_img": ("IMAGE",), - "end_img": ("IMAGE",), - "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}), - "context_options": ("COGCONTEXT", ), - "tora_trajectory": ("TORAFEATURES", ), - "fastercache": ("FASTERCACHEARGS",), - "vid2vid_images": ("IMAGE",), - "vid2vid_denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }, - } - - RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",) - RETURN_NAMES = ("cogvideo_pipe", "samples",) - FUNCTION = "process" - CATEGORY = "CogVideoWrapper" - - def process(self, pipeline, positive, negative, video_length, width, height, seed, steps, cfg, scheduler, - start_img=None, end_img=None, noise_aug_strength=0.0563, context_options=None, fastercache=None, - tora_trajectory=None, vid2vid_images=None, vid2vid_denoise=1.0): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - pipe = pipeline["pipe"] - dtype = pipeline["dtype"] - base_path = pipeline["base_path"] - assert "fun" in base_path.lower(), "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'" - assert "pose" not in base_path.lower(), "'Pose' models not supported in 'CogVideoXFunSampler', use the 'CogVideoXFunControlSampler'" - - mm.soft_empty_cache() - - #vid2vid - if vid2vid_images is not None: - validation_video = np.array(vid2vid_images.cpu().numpy() * 255, np.uint8) - #img2vid - elif start_img is not None: - start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None - end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None - - # Load Sampler - scheduler_config = pipeline["scheduler_config"] - if scheduler in 
scheduler_mapping: - noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) - pipe.scheduler = noise_scheduler - else: - raise ValueError(f"Unknown scheduler: {scheduler}") - - if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]: - pipe.transformer.to(device) - - if context_options is not None: - context_frames = context_options["context_frames"] // 4 - context_stride = context_options["context_stride"] // 4 - context_overlap = context_options["context_overlap"] // 4 - else: - context_frames, context_stride, context_overlap = None, None, None - - if tora_trajectory is not None: - pipe.transformer.fuser_list = tora_trajectory["fuser_list"] - - if fastercache is not None: - pipe.transformer.use_fastercache = True - pipe.transformer.fastercache_counter = 0 - pipe.transformer.fastercache_start_step = fastercache["start_step"] - pipe.transformer.fastercache_lf_step = fastercache["lf_step"] - pipe.transformer.fastercache_hf_step = fastercache["hf_step"] - pipe.transformer.fastercache_device = fastercache["cache_device"] - pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"] - log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}") - else: - pipe.transformer.use_fastercache = False - pipe.transformer.fastercache_counter = 0 - - generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed) - - autocastcondition = not pipeline["onediff"] or not dtype == torch.float32 - autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext() - with autocast_context: - video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 - if vid2vid_images is not None: - input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width)) - else: - input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width)) - - common_params = { - "prompt_embeds": positive.to(dtype).to(device), - "negative_prompt_embeds": negative.to(dtype).to(device), - "num_frames": video_length, - "height": height, - "width": width, - "generator": generator, - "guidance_scale": cfg, - "num_inference_steps": steps, - "comfyui_progressbar": True, - "context_schedule":context_options["context_schedule"] if context_options is not None else None, - "context_frames":context_frames, - "context_stride": context_stride, - "context_overlap": context_overlap, - "freenoise":context_options["freenoise"] if context_options is not None else None, - "tora":tora_trajectory if tora_trajectory is not None else None, - } - latents = pipe( - **common_params, - video = input_video, - mask_video = input_video_mask, - noise_aug_strength = noise_aug_strength, - strength = vid2vid_denoise, - ) - if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]: - pipe.transformer.to(offload_device) - #clear FasterCache - if fastercache is not None: - for block in pipe.transformer.transformer_blocks: - if (hasattr, block, "cached_hidden_states") and block.cached_hidden_states is not None: - block.cached_hidden_states = None - block.cached_encoder_hidden_states = None - - mm.soft_empty_cache() - - return (pipeline, {"samples": latents}) - -class CogVideoXFunVid2VidSampler: - @classmethod - def 
INPUT_TYPES(s): - return { - "required": { - "note": ("STRING", {"default": "This node is deprecated, functionality moved to 'CogVideoXFunSampler' node instead.", "multiline": True}), - }, - } - - RETURN_TYPES = () - FUNCTION = "process" - CATEGORY = "CogVideoWrapper" - DEPRECATED = True - def process(self): - return () - -class CogVideoXFunControlSampler: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "pipeline": ("COGVIDEOPIPE",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_latents": ("COGCONTROL_LATENTS",), - "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}), - "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}), - "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}), - "scheduler": (available_schedulers, {"default": 'DDIM'}), - "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }, - "optional": { - "samples": ("LATENT", ), - "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "context_options": ("COGCONTEXT", ), - }, - } - - RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",) - RETURN_NAMES = ("cogvideo_pipe", "samples",) - FUNCTION = "process" - CATEGORY = "CogVideoWrapper" - - def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents, - control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, - samples=None, denoise_strength=1.0, context_options=None): - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() - pipe = pipeline["pipe"] - dtype = pipeline["dtype"] - base_path = pipeline["base_path"] - - assert "fun" in base_path.lower(), "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'" - - if not pipeline["cpu_offloading"]: - pipe.enable_model_cpu_offload(device=device) - - mm.soft_empty_cache() - - if context_options is not None: - context_frames = context_options["context_frames"] // 4 - context_stride = context_options["context_stride"] // 4 - context_overlap = context_options["context_overlap"] // 4 - else: - context_frames, context_stride, context_overlap = None, None, None - - # Load Sampler - scheduler_config = pipeline["scheduler_config"] - if scheduler in scheduler_mapping: - noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) - pipe.scheduler = noise_scheduler - else: - raise ValueError(f"Unknown scheduler: {scheduler}") - - generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed) - - autocastcondition = not pipeline["onediff"] or not dtype == torch.float32 - autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() - with autocast_context: - - common_params = { - "prompt_embeds": positive.to(dtype).to(device), - "negative_prompt_embeds": negative.to(dtype).to(device), - "num_frames": control_latents["num_frames"], - "height": control_latents["height"], - "width": control_latents["width"], - "generator": generator, - "guidance_scale": cfg, - "num_inference_steps": steps, - "comfyui_progressbar": True, - } - - latents = pipe( - **common_params, - control_video=control_latents["latents"], - control_strength=control_strength, - control_start_percent=control_start_percent, - control_end_percent=control_end_percent, - 
scheduler_name=scheduler, - latents=samples["samples"] if samples is not None else None, - denoise_strength=denoise_strength, - context_schedule=context_options["context_schedule"] if context_options is not None else None, - context_frames=context_frames, - context_stride= context_stride, - context_overlap= context_overlap, - freenoise=context_options["freenoise"] if context_options is not None else None - - ) - - return (pipeline, {"samples": latents}) - class CogVideoLatentPreview: @classmethod def INPUT_TYPES(s): @@ -1332,9 +896,6 @@ class CogVideoLatentPreview: latents = samples["samples"].clone() print("in sample", latents.shape) latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() #[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]] latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]] @@ -1374,15 +935,9 @@ NODE_CLASS_MAPPINGS = { "CogVideoSampler": CogVideoSampler, "CogVideoDecode": CogVideoDecode, "CogVideoTextEncode": CogVideoTextEncode, - "CogVideoDualTextEncode_311": CogVideoDualTextEncode_311, "CogVideoImageEncode": CogVideoImageEncode, - "CogVideoImageInterpolationEncode": CogVideoImageInterpolationEncode, - "CogVideoXFunSampler": CogVideoXFunSampler, - "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler, - "CogVideoXFunControlSampler": CogVideoXFunControlSampler, "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine, 
"CogVideoTransformerEdit": CogVideoTransformerEdit, - "CogVideoControlImageEncode": CogVideoControlImageEncode, "CogVideoContextOptions": CogVideoContextOptions, "CogVideoControlNet": CogVideoControlNet, "ToraEncodeTrajectory": ToraEncodeTrajectory, @@ -1390,21 +945,16 @@ NODE_CLASS_MAPPINGS = { "CogVideoXFasterCache": CogVideoXFasterCache, "CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket, "CogVideoLatentPreview": CogVideoLatentPreview, - "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings + "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings, + "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP, } NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoSampler": "CogVideo Sampler", "CogVideoDecode": "CogVideo Decode", "CogVideoTextEncode": "CogVideo TextEncode", - "CogVideoDualTextEncode_311": "CogVideo DualTextEncode", "CogVideoImageEncode": "CogVideo ImageEncode", - "CogVideoImageInterpolationEncode": "CogVideo ImageInterpolation Encode", - "CogVideoXFunSampler": "CogVideoXFun Sampler", - "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler", - "CogVideoXFunControlSampler": "CogVideoXFun Control Sampler", "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine", "CogVideoTransformerEdit": "CogVideo TransformerEdit", - "CogVideoControlImageEncode": "CogVideo Control ImageEncode", "CogVideoContextOptions": "CogVideo Context Options", "ToraEncodeTrajectory": "Tora Encode Trajectory", "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow", @@ -1412,4 +962,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket", "CogVideoLatentPreview": "CogVideo LatentPreview", "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings", + "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP", } \ No newline at end of file diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 472a308..f869ce4 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -17,15 +17,13 @@ import inspect from typing import Callable, Dict, List, Optional, Tuple, Union import torch -import torch.nn.functional as F import math -from diffusers.models import AutoencoderKLCogVideoX from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor + #from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.loaders import CogVideoXLoraLoaderMixin @@ -120,15 +118,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - text_encoder ([`T5EncoderModel`]): - Frozen text-encoder. CogVideoX uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the - [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. - tokenizer (`T5Tokenizer`): - Tokenizer of class - [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). transformer ([`CogVideoXTransformer3DModel`]): A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. 
scheduler ([`SchedulerMixin`]): @@ -140,31 +129,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): def __init__( self, - vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - original_mask = None, + dtype: torch.dtype = torch.bfloat16, + is_fun_inpaint: bool = False, ): super().__init__() - self.register_modules( - vae=vae, transformer=transformer, scheduler=scheduler - ) - self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.original_mask = original_mask - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.video_processor.config.do_resize = False + self.register_modules(transformer=transformer, scheduler=scheduler) + self.vae_scale_factor_spatial = 8 + self.vae_scale_factor_temporal = 4 + self.vae_latent_channels = 16 + self.vae_dtype = dtype + self.is_fun_inpaint = is_fun_inpaint self.input_with_padding = True def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, + self, batch_size, num_channels_latents, num_frames, height, width, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None ): shape = ( @@ -174,14 +157,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype) + + noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae_dtype) if freenoise: - print("Applying FreeNoise") + logger.info("Applying FreeNoise") # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) video_length = num_frames // 4 delta = context_size - context_overlap @@ -221,20 +200,20 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device) latent_timestep = timesteps[:1] - noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype) frames_needed = noise.shape[1] current_frames = latents.shape[1] if frames_needed > current_frames: - repeat_factor = frames_needed // current_frames + repeat_factor = frames_needed - current_frames additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device) - latents = torch.cat((latents, additional_frame), dim=1) + latents = torch.cat((additional_frame, latents), dim=1) + self.additional_frames = repeat_factor elif frames_needed < current_frames: latents = latents[:, :frames_needed, :, :, :] - latents = self.scheduler.add_noise(latents, noise, latent_timestep) + latents = self.scheduler.add_noise(latents, noise.to(device), latent_timestep) latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler - return latents, timesteps, noise + return latents, timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -355,10 +334,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): guidance_scale: float = 6, denoise_strength: float = 1.0, sigmas: Optional[List[float]] = None, - num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + fun_mask: Optional[torch.Tensor] = None, image_cond_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, @@ -398,8 +377,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -443,7 +420,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_embeds = prompt_embeds.to(self.vae.dtype) + prompt_embeds = prompt_embeds.to(self.vae_dtype) # 4. 
Prepare timesteps if sigmas is None: @@ -453,7 +430,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): self._num_timesteps = len(timesteps) # 5. Prepare latents. - latent_channels = self.vae.config.latent_channels + latent_channels = self.vae_latent_channels latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t @@ -469,18 +446,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): self.additional_frames = patch_size_t - latent_frames % patch_size_t num_frames += self.additional_frames * self.vae_scale_factor_temporal - - if self.original_mask is not None: - image_latents = latents - original_image_latents = image_latents - - latents, timesteps, noise = self.prepare_latents( - batch_size * num_videos_per_prompt, + latents, timesteps = self.prepare_latents( + batch_size, latent_channels, num_frames, height, width, - self.vae.dtype, device, generator, timesteps, @@ -491,37 +462,41 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): context_overlap=context_overlap, freenoise=freenoise, ) - latents = latents.to(self.vae.dtype) + latents = latents.to(self.vae_dtype) + + if self.is_fun_inpaint and fun_mask is None: # For FUN inpaint vid2vid, we need to mask all the latents + fun_mask = torch.zeros_like(latents[:, :, :1, :, :], device=latents.device, dtype=latents.dtype) + fun_masked_video_latents = torch.zeros_like(latents, device=latents.device, dtype=latents.dtype) # 5.5. if image_cond_latents is not None: - if image_cond_latents.shape[1] > 1: + if image_cond_latents.shape[1] == 2: logger.info("More than one image conditioning frame received, interpolating") padding_shape = ( - batch_size, - (latents.shape[1] - 2), - self.vae.config.latent_channels, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, + batch_size, + (latents.shape[1] - 2), + self.vae_latent_channels, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, ) - latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype) + latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype) image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1) if self.transformer.config.patch_size_t is not None: - first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...] - image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1) + first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...] 
+ image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1) logger.info(f"image cond latents shape: {image_cond_latents.shape}") - else: + elif image_cond_latents.shape[1] == 1: logger.info("Only one image conditioning frame received, img2vid") if self.input_with_padding: padding_shape = ( batch_size, (latents.shape[1] - 1), - self.vae.config.latent_channels, + self.vae_latent_channels, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) - latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype) + latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype) image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1) # Select the first frame along the second dimension if self.transformer.config.patch_size_t is not None: @@ -529,22 +504,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1) else: image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1) + else: + logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames") + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # masks - if self.original_mask is not None: - mask = self.original_mask.to(device) - logger.info(f"self.original_mask: {self.original_mask.shape}") - - mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False) - if mask.shape[0] != latents.shape[1]: - mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1) - else: - mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1) - logger.info(f"latents: {latents.shape}") - logger.info(f"mask: {mask.shape}") - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -554,7 +518,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): raise NotImplementedError("Context schedule not currently supported with image conditioning") logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") use_context_schedule = True - from .cogvideox_fun.context import get_context_scheduler + from .context import get_context_scheduler context = get_context_scheduler(context_schedule) #todo ofs embeds? 
@@ -747,7 +711,18 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): if image_cond_latents is not None: latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents - latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + if fun_mask is not None: #for fun img2vid and interpolation + fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask + masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2) + latent_model_input = torch.cat([latent_model_input, masks_input], dim=2) + else: + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + else: # for Fun inpaint vid2vid + if fun_mask is not None: + fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask + fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents + fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype) + latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -767,9 +742,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): return_dict=False, )[0] if isinstance(controlnet_states, (tuple, list)): - controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states] + controlnet_states = [x.to(dtype=self.vae_dtype) for x in controlnet_states] else: - controlnet_states = controlnet_states.to(dtype=self.vae.dtype) + controlnet_states = controlnet_states.to(dtype=self.vae_dtype) # predict noise model_output @@ -796,30 +771,18 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): # compute the previous noisy sample x_t -> x_t-1 if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0] else: latents, old_pred_original_sample = self.scheduler.step( noise_pred, old_pred_original_sample, t, timesteps[i - 1] if i > 0 else None, - latents.to(self.vae.dtype), + latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False, ) latents = latents.to(prompt_embeds.dtype) - # start diff diff - if i < len(timesteps) - 1 and self.original_mask is not None: - noise_timestep = timesteps[i + 1] - image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep]) - ) - mask = mask.to(latents) - ts_from = timesteps[0] - ts_to = timesteps[-1] - threshold = (t - ts_to) / (ts_from - ts_to) - mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask)) - latents = image_latent * mask + latents * (1 - mask) - # end diff diff if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/pyproject.toml b/pyproject.toml index 78b9ed8..1ca05f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "comfyui-cogvideoxwrapper" description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)" -version = "1.1.0" +version = "1.5.0" license = {file = "LICENSE"} -dependencies = ["huggingface_hub", 
"diffusers>=0.30.1", "accelerate>=0.33.0"] +dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"] [project.urls] Repository = "https://github.com/kijai/ComfyUI-CogVideoXWrapper"