From 0bd3da569ead6b36982dbeb55500a1c4963d137d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Thu, 14 Nov 2024 19:54:52 +0200
Subject: [PATCH] code cleanup
codebase getting too bloated:
drop PAB support in favor of FasterCache
drop temporal tilling in favor of FreeNoise
---
cogvideox_fun/fun_pab_transformer_3d.py | 741 --------------------
cogvideox_fun/pipeline_cogvideox_control.py | 116 +--
cogvideox_fun/pipeline_cogvideox_inpaint.py | 126 +---
model_loading.py | 50 +-
nodes.py | 64 +-
pipeline_cogvideox.py | 154 +---
videosys/cogvideox_transformer_3d.py | 621 ----------------
videosys/core/__init__.py | 0
videosys/core/pab_mgr.py | 232 ------
videosys/core/pipeline.py | 44 --
videosys/modules/__init__.py | 0
videosys/modules/activations.py | 3 -
videosys/modules/downsampling.py | 71 --
videosys/modules/embeddings.py | 308 --------
videosys/modules/normalization.py | 85 ---
videosys/modules/upsampling.py | 67 --
videosys/pab.py | 64 --
17 files changed, 35 insertions(+), 2711 deletions(-)
delete mode 100644 cogvideox_fun/fun_pab_transformer_3d.py
delete mode 100644 videosys/cogvideox_transformer_3d.py
delete mode 100644 videosys/core/__init__.py
delete mode 100644 videosys/core/pab_mgr.py
delete mode 100644 videosys/core/pipeline.py
delete mode 100644 videosys/modules/__init__.py
delete mode 100644 videosys/modules/activations.py
delete mode 100644 videosys/modules/downsampling.py
delete mode 100644 videosys/modules/embeddings.py
delete mode 100644 videosys/modules/normalization.py
delete mode 100644 videosys/modules/upsampling.py
delete mode 100644 videosys/pab.py
diff --git a/cogvideox_fun/fun_pab_transformer_3d.py b/cogvideox_fun/fun_pab_transformer_3d.py
deleted file mode 100644
index 25a3934..0000000
--- a/cogvideox_fun/fun_pab_transformer_3d.py
+++ /dev/null
@@ -1,741 +0,0 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, Optional, Tuple, Union
-
-import os
-import json
-import torch
-import glob
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import is_torch_version, logging
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from diffusers.models.attention import Attention, FeedForward
-#from diffusers.models.attention_processor import AttentionProcessor
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.modeling_utils import ModelMixin
-#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
-
-from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
-from ..videosys.modules.embeddings import apply_rotary_emb
-from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-try:
- from sageattention import sageattn
- SAGEATTN_IS_AVAVILABLE = True
- logger.info("Using sageattn")
-except:
- logger.info("sageattn not found, using sdpa")
- SAGEATTN_IS_AVAVILABLE = False
-
-class CogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- # if attn.parallel_manager.sp_size > 1:
- # assert (
- # attn.heads % attn.parallel_manager.sp_size == 0
- # ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
- # attn_heads = attn.heads // attn.parallel_manager.sp_size
- # query, key, value = map(
- # lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
- # [query, key, value],
- # )
-
- attn_heads = attn.heads
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn_heads
-
- query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- emb_len = image_rotary_emb[0].shape[0]
- query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
- query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
- )
- if not attn.is_cross_attention:
- key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
- key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
- )
-
- if SAGEATTN_IS_AVAVILABLE:
- hidden_states = sageattn(query, key, value, is_causal=False)
- else:
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
-
- #if attn.parallel_manager.sp_size > 1:
- # hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- if SAGEATTN_IS_AVAVILABLE:
- hidden_states = sageattn(query, key, value, is_causal=False)
- else:
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-class CogVideoXPatchEmbed(nn.Module):
- def __init__(
- self,
- patch_size: int = 2,
- in_channels: int = 16,
- embed_dim: int = 1920,
- text_embed_dim: int = 4096,
- bias: bool = True,
- ) -> None:
- super().__init__()
- self.patch_size = patch_size
-
- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
- r"""
- Args:
- text_embeds (`torch.Tensor`):
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
- image_embeds (`torch.Tensor`):
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
- """
- text_embeds = self.text_proj(text_embeds)
-
- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
-
- embeds = torch.cat(
- [text_embeds, image_embeds], dim=1
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
- return embeds
-
-@maybe_allow_in_graph
-class CogVideoXBlock(nn.Module):
- r"""
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
-
- Parameters:
- dim (`int`):
- The number of channels in the input and output.
- num_attention_heads (`int`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`):
- The number of channels in each head.
- time_embed_dim (`int`):
- The number of channels in timestep embedding.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to be used in feed-forward.
- attention_bias (`bool`, defaults to `False`):
- Whether or not to use bias in attention projection layers.
- qk_norm (`bool`, defaults to `True`):
- Whether or not to use normalization after query and key projections in Attention.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_eps (`float`, defaults to `1e-5`):
- Epsilon value for normalization layers.
- final_dropout (`bool` defaults to `False`):
- Whether to apply a final dropout after the last feed-forward layer.
- ff_inner_dim (`int`, *optional*, defaults to `None`):
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
- ff_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Feed-forward layer.
- attention_out_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Attention output projection layer.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- time_embed_dim: int,
- dropout: float = 0.0,
- activation_fn: str = "gelu-approximate",
- attention_bias: bool = False,
- qk_norm: bool = True,
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- final_dropout: bool = True,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- block_idx: int = 0,
- ):
- super().__init__()
-
- # 1. Self Attention
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.attn1 = Attention(
- query_dim=dim,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- qk_norm="layer_norm" if qk_norm else None,
- eps=1e-6,
- bias=attention_bias,
- out_bias=attention_out_bias,
- processor=CogVideoXAttnProcessor2_0(),
- )
-
- # parallel
- #self.attn1.parallel_manager = None
-
- # 2. Feed Forward
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
-
- # pab
- self.attn_count = 0
- self.last_attn = None
- self.block_idx = block_idx
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- temb: torch.Tensor,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- timestep=None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
- hidden_states, encoder_hidden_states, temb
- )
-
- # attention
- if enable_pab():
- broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
- if enable_pab() and broadcast_attn:
- attn_hidden_states, attn_encoder_hidden_states = self.last_attn
- else:
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=norm_encoder_hidden_states,
- image_rotary_emb=image_rotary_emb,
- )
- if enable_pab():
- self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
-
- hidden_states = hidden_states + gate_msa * attn_hidden_states
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
- hidden_states, encoder_hidden_states, temb
- )
-
- # feed-forward
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- ff_output = self.ff(norm_hidden_states)
-
- hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
-
- return hidden_states, encoder_hidden_states
-
-
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
- """
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
-
- Parameters:
- num_attention_heads (`int`, defaults to `30`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`, defaults to `64`):
- The number of channels in each head.
- in_channels (`int`, defaults to `16`):
- The number of channels in the input.
- out_channels (`int`, *optional*, defaults to `16`):
- The number of channels in the output.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- time_embed_dim (`int`, defaults to `512`):
- Output dimension of timestep embeddings.
- text_embed_dim (`int`, defaults to `4096`):
- Input dimension of text embeddings from the text encoder.
- num_layers (`int`, defaults to `30`):
- The number of layers of Transformer blocks to use.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- attention_bias (`bool`, defaults to `True`):
- Whether or not to use bias in the attention projection layers.
- sample_width (`int`, defaults to `90`):
- The width of the input latents.
- sample_height (`int`, defaults to `60`):
- The height of the input latents.
- sample_frames (`int`, defaults to `49`):
- The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
- instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
- but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
- K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
- patch_size (`int`, defaults to `2`):
- The size of the patches to use in the patch embedding layer.
- temporal_compression_ratio (`int`, defaults to `4`):
- The compression ratio across the temporal dimension. See documentation for `sample_frames`.
- max_text_seq_length (`int`, defaults to `226`):
- The maximum sequence length of the input text embeddings.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to use in feed-forward.
- timestep_activation_fn (`str`, defaults to `"silu"`):
- Activation function to use when generating the timestep embeddings.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether or not to use elementwise affine in normalization layers.
- norm_eps (`float`, defaults to `1e-5`):
- The epsilon value to use in normalization layers.
- spatial_interpolation_scale (`float`, defaults to `1.875`):
- Scaling factor to apply in 3D positional embeddings across spatial dimensions.
- temporal_interpolation_scale (`float`, defaults to `1.0`):
- Scaling factor to apply in 3D positional embeddings across temporal dimensions.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 30,
- attention_head_dim: int = 64,
- in_channels: int = 16,
- out_channels: Optional[int] = 16,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- time_embed_dim: int = 512,
- text_embed_dim: int = 4096,
- num_layers: int = 30,
- dropout: float = 0.0,
- attention_bias: bool = True,
- sample_width: int = 90,
- sample_height: int = 60,
- sample_frames: int = 49,
- patch_size: int = 2,
- temporal_compression_ratio: int = 4,
- max_text_seq_length: int = 226,
- activation_fn: str = "gelu-approximate",
- timestep_activation_fn: str = "silu",
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- spatial_interpolation_scale: float = 1.875,
- temporal_interpolation_scale: float = 1.0,
- use_rotary_positional_embeddings: bool = False,
- add_noise_in_inpaint_model: bool = False,
- ):
- super().__init__()
- inner_dim = num_attention_heads * attention_head_dim
-
- post_patch_height = sample_height // patch_size
- post_patch_width = sample_width // patch_size
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
- self.post_patch_height = post_patch_height
- self.post_patch_width = post_patch_width
- self.post_time_compression_frames = post_time_compression_frames
- self.patch_size = patch_size
-
- # 1. Patch embedding
- self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
- self.embedding_dropout = nn.Dropout(dropout)
-
- # 2. 3D positional embeddings
- spatial_pos_embedding = get_3d_sincos_pos_embed(
- inner_dim,
- (post_patch_width, post_patch_height),
- post_time_compression_frames,
- spatial_interpolation_scale,
- temporal_interpolation_scale,
- )
- spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
- pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
- pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
- self.register_buffer("pos_embedding", pos_embedding, persistent=False)
-
- # 3. Time embeddings
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
-
- # 4. Define spatio-temporal transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- CogVideoXBlock(
- dim=inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- time_embed_dim=time_embed_dim,
- dropout=dropout,
- activation_fn=activation_fn,
- attention_bias=attention_bias,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- )
- for _ in range(num_layers)
- ]
- )
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
-
- # 5. Output blocks
- self.norm_out = AdaLayerNorm(
- embedding_dim=time_embed_dim,
- output_dim=2 * inner_dim,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- chunk_dim=1,
- )
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
- self.gradient_checkpointing = False
-
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- timestep: Union[int, float, torch.LongTensor],
- timestep_cond: Optional[torch.Tensor] = None,
- inpaint_latents: Optional[torch.Tensor] = None,
- control_latents: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- return_dict: bool = True,
- ):
- batch_size, num_frames, channels, height, width = hidden_states.shape
-
- # 1. Time embedding
- timesteps = timestep
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=hidden_states.dtype)
- emb = self.time_embedding(t_emb, timestep_cond)
-
- # 2. Patch embedding
- if inpaint_latents is not None:
- hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
- if control_latents is not None:
- hidden_states = torch.concat([hidden_states, control_latents], 2)
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-
- # 3. Position embedding
- text_seq_length = encoder_hidden_states.shape[1]
- if not self.config.use_rotary_positional_embeddings:
- seq_length = height * width * num_frames // (self.config.patch_size**2)
- # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
- pos_embeds = self.pos_embedding
- emb_size = hidden_states.size()[-1]
- pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
- pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
- pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
- pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
- hidden_states = hidden_states + pos_embeds
- hidden_states = self.embedding_dropout(hidden_states)
-
- encoder_hidden_states = hidden_states[:, :text_seq_length]
- hidden_states = hidden_states[:, text_seq_length:]
-
- # 4. Transformer blocks
-
- for i, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- emb,
- image_rotary_emb,
- **ckpt_kwargs,
- )
- else:
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=emb,
- image_rotary_emb=image_rotary_emb,
- timestep=timestep,
- )
-
- if not self.config.use_rotary_positional_embeddings:
- # CogVideoX-2B
- hidden_states = self.norm_final(hidden_states)
- else:
- # CogVideoX-5B
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states = self.norm_final(hidden_states)
- hidden_states = hidden_states[:, text_seq_length:]
-
- # 5. Final block
- hidden_states = self.norm_out(hidden_states, temb=emb)
- hidden_states = self.proj_out(hidden_states)
-
- # 6. Unpatchify
- p = self.config.patch_size
- output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-
- if not return_dict:
- return (output,)
- return Transformer2DModelOutput(sample=output)
-
- @classmethod
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
- if subfolder is not None:
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
-
- config_file = os.path.join(pretrained_model_path, 'config.json')
- if not os.path.isfile(config_file):
- raise RuntimeError(f"{config_file} does not exist")
- with open(config_file, "r") as f:
- config = json.load(f)
-
- from diffusers.utils import WEIGHTS_NAME
- model = cls.from_config(config, **transformer_additional_kwargs)
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
- if os.path.exists(model_file):
- state_dict = torch.load(model_file, map_location="cpu")
- elif os.path.exists(model_file_safetensors):
- from safetensors.torch import load_file, safe_open
- state_dict = load_file(model_file_safetensors)
- else:
- from safetensors.torch import load_file, safe_open
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
- state_dict = {}
- for model_file_safetensors in model_files_safetensors:
- _state_dict = load_file(model_file_safetensors)
- for key in _state_dict:
- state_dict[key] = _state_dict[key]
-
- if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
- new_shape = model.state_dict()['patch_embed.proj.weight'].size()
- if len(new_shape) == 5:
- state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
- state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
- else:
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
- else:
- model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
-
- tmp_state_dict = {}
- for key in state_dict:
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
- tmp_state_dict[key] = state_dict[key]
- else:
- print(key, "Size don't match, skip")
- state_dict = tmp_state_dict
-
- m, u = model.load_state_dict(state_dict, strict=False)
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
- print(m)
-
- params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
- print(f"### Mamba Parameters: {sum(params) / 1e6} M")
-
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
-
- return model
\ No newline at end of file
diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 85687fe..f598147 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -33,10 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
-from ..videosys.core.pipeline import VideoSysPipeline
-from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
-from ..videosys.core.pab_mgr import set_pab_manager
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -158,7 +154,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
videos: torch.Tensor
-class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
+class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -188,7 +184,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
- pab_config = None
):
super().__init__()
@@ -210,9 +205,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
- if pab_config is not None:
- set_pab_manager(pab_config)
-
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
latents=None, freenoise=True, context_size=None, context_overlap=None
@@ -348,16 +340,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
-
- def _gaussian_weights(self, t_tile_length, t_batch_size):
- from numpy import pi, exp, sqrt
-
- var = 0.01
- midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
- t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
- weights = torch.tensor(t_probs)
- weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
- return weights
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
@@ -697,24 +679,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- # 8.5. Temporal tiling prep
- if context_schedule is not None and context_schedule == "temporal_tiling":
- t_tile_length = context_frames
- t_tile_overlap = context_overlap
- t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
- use_temporal_tiling = True
- print("Temporal tiling enabled")
- elif context_schedule is not None:
+ if context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
- use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
- use_temporal_tiling = False
use_context_schedule = False
- print("Temporal tiling and context schedule disabled")
+ print(" context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@@ -735,88 +708,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
for i, t in enumerate(timesteps):
if self.interrupt:
continue
-
- if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
- #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
- # =====================================================
- grid_ts = 0
- cur_t = 0
- while cur_t < latents.shape[1]:
- cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
- grid_ts += 1
-
- all_t = latents.shape[1]
- latents_all_list = []
- # =====================================================
-
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
- control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
-
- latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
- latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
-
- #t_input = t[None].to(device)
- t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-
- # predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input_tile,
- encoder_hidden_states=prompt_embeds,
- timestep=t_input,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- control_latents=control_latents_tile,
- )[0]
- noise_pred = noise_pred.float()
-
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
- latents_all_list.append(latents_tile)
-
- # ==========================================
- latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- # Add each tile contribution to overall latents
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
- contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
-
- latents_all /= contributors
-
- latents = latents_all
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- pbar.update(1)
- # ==========================================
- elif use_context_schedule:
+ if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index 4b3b4f3..7b9d8e7 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -33,11 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
-from ..videosys.core.pipeline import VideoSysPipeline
-from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
-from ..videosys.core.pab_mgr import set_pab_manager
-
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -206,7 +201,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
videos: torch.Tensor
-class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
+class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -236,7 +231,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
- pab_config = None
):
super().__init__()
@@ -258,9 +252,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
- if pab_config is not None:
- set_pab_manager(pab_config)
-
def prepare_latents(
self,
batch_size,
@@ -433,16 +424,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
- def _gaussian_weights(self, t_tile_length, t_batch_size):
- from numpy import pi, exp, sqrt
-
- var = 0.01
- midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
- t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
- weights = torch.tensor(t_probs)
- weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
- return weights
-
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
@@ -866,22 +847,14 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
- if context_schedule is not None and context_schedule == "temporal_tiling":
- t_tile_length = context_frames
- t_tile_overlap = context_overlap
- t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
- use_temporal_tiling = True
- print("Temporal tiling enabled")
- elif context_schedule is not None:
+ if context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
- use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
- use_temporal_tiling = False
use_context_schedule = False
- print("Temporal tiling and context schedule disabled")
+ print("context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@@ -915,87 +888,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
if self.interrupt:
continue
- if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
- #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
- # =====================================================
- grid_ts = 0
- cur_t = 0
- while cur_t < latents.shape[1]:
- cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
- grid_ts += 1
-
- all_t = latents.shape[1]
- latents_all_list = []
- # =====================================================
-
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
- inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
-
- latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
- latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
-
- #t_input = t[None].to(device)
- t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-
- # predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input_tile,
- encoder_hidden_states=prompt_embeds,
- timestep=t_input,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- inpaint_latents=inpaint_latents_tile,
- )[0]
- noise_pred = noise_pred.float()
-
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
- latents_all_list.append(latents_tile)
-
- # ==========================================
- latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- # Add each tile contribution to overall latents
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
- contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
-
- latents_all /= contributors
-
- latents = latents_all
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- pbar.update(1)
- # ==========================================
- elif use_context_schedule:
+ if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1133,18 +1026,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
else:
pbar.update(1)
- # if output_type == "numpy":
- # video = self.decode_latents(latents)
- # elif not output_type == "latent":
- # video = self.decode_latents(latents)
- # video = self.video_processor.postprocess_video(video=video, output_type=output_type)
- # else:
- # video = latents
-
# Offload all models
self.maybe_free_model_hooks()
- # if not return_dict:
- # video = torch.from_numpy(video)
-
return latents
\ No newline at end of file
diff --git a/model_loading.py b/model_loading.py
index fe9245b..532d6aa 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -12,15 +12,12 @@ from .pipeline_cogvideox import CogVideoXPipeline
from contextlib import nullcontext
from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
-from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
-from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
-
-from .utils import check_diffusers_version, remove_specific_blocks, log
+from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file
script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -95,7 +92,6 @@ class DownloadAndLoadCogVideoModel:
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
- "pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
@@ -111,7 +107,7 @@ class DownloadAndLoadCogVideoModel:
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
- enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
+ enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
if precision == "fp16" and "1.5" in model:
@@ -188,15 +184,9 @@ class DownloadAndLoadCogVideoModel:
# transformer
if "Fun" in model:
- if pab_config is not None:
- transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder=subfolder)
- else:
- transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
+ transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
else:
- if pab_config is not None:
- transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder=subfolder)
- else:
- transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
+ transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
@@ -213,12 +203,12 @@ class DownloadAndLoadCogVideoModel:
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
if "Pose" in model:
- pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
else:
- pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
- pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoXPipeline(vae, transformer, scheduler)
if "cogvideox-2b-img2vid" in model:
pipe.input_with_padding = False
@@ -296,7 +286,7 @@ class DownloadAndLoadCogVideoModel:
backend="nexfort",
options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
ignores=["vae"],
- fuse_qkv_projections=True if pab_config is None else False,
+ fuse_qkv_projections= False,
)
pipeline = {
@@ -334,7 +324,6 @@ class DownloadAndLoadCogVideoGGUFModel:
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
},
"optional": {
- "pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
#"lora": ("COGLORA", {"default": None}),
"compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
@@ -348,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
- pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"):
+ block_edit=None, compile="disabled", attention_mode="sdpa"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -396,10 +385,7 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
- if pab_config is not None:
- transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
- else:
- transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
+ transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
elif "I2V" in model or "Interpolation" in model:
transformer_config["in_channels"] = 32
if "1_5" in model:
@@ -409,16 +395,10 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 96
transformer_config["sample_width"] = 170
- if pab_config is not None:
- transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
- else:
- transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+ transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
else:
transformer_config["in_channels"] = 16
- if pab_config is not None:
- transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
- else:
- transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+ transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
@@ -476,13 +456,13 @@ class DownloadAndLoadCogVideoGGUFModel:
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
if "Pose" in model:
- pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
else:
- pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
- pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
+ pipe = CogVideoXPipeline(vae, transformer, scheduler)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
diff --git a/nodes.py b/nodes.py
index ecea9db..cd90f33 100644
--- a/nodes.py
+++ b/nodes.py
@@ -44,8 +44,6 @@ from PIL import Image
import numpy as np
import json
-
-
script_directory = os.path.dirname(os.path.abspath(__file__))
if not "CogVideo" in folder_paths.folder_names_and_paths:
@@ -53,61 +51,11 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
-#PAB
-from .videosys.pab import CogVideoXPABConfig
-
-class CogVideoPABConfig:
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
- "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
- "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
- "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
- "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
- "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
- "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
- "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
- "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
- "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
- "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
- "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
-
- "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
- }
- }
-
- RETURN_TYPES = ("PAB_CONFIG",)
- RETURN_NAMES = ("pab_config", )
- FUNCTION = "config"
- CATEGORY = "CogVideoWrapper"
- DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
-
- def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
- temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
- cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
-
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
- pab_config = CogVideoXPABConfig(
- steps=steps,
- spatial_broadcast=spatial_broadcast,
- spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
- spatial_range=spatial_range,
- temporal_broadcast=temporal_broadcast,
- temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
- temporal_range=temporal_range,
- cross_broadcast=cross_broadcast,
- cross_threshold=[cross_threshold_end, cross_threshold_start],
- cross_range=cross_range
- )
-
- return (pab_config, )
-
class CogVideoContextOptions:
@classmethod
def INPUT_TYPES(s):
return {"required": {
- "context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
+ "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],),
"context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
@@ -1152,9 +1100,6 @@ class CogVideoXFunSampler:
end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
# Load Sampler
- if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
- log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
- scheduler="CogVideoXDDIM"
scheduler_config = pipeline["scheduler_config"]
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
@@ -1282,7 +1227,7 @@ class CogVideoXFunControlSampler:
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
- control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
+ control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
samples=None, denoise_strength=1.0, context_options=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -1306,9 +1251,6 @@ class CogVideoXFunControlSampler:
# Load Sampler
scheduler_config = pipeline["scheduler_config"]
- if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
- log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
- scheduler="CogVideoXDDIM"
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
pipe.scheduler = noise_scheduler
@@ -1427,7 +1369,6 @@ NODE_CLASS_MAPPINGS = {
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
"CogVideoXFunControlSampler": CogVideoXFunControlSampler,
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
- "CogVideoPABConfig": CogVideoPABConfig,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoContextOptions": CogVideoContextOptions,
@@ -1450,7 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
"CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
- "CogVideoPABConfig": "CogVideo PABConfig",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoContextOptions": "CogVideo Context Options",
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 87d19e9..09e9103 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -20,8 +20,8 @@ import torch
import torch.nn.functional as F
import math
-from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
-#from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
@@ -35,10 +35,6 @@ from comfy.utils import ProgressBar
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-from .videosys.core.pipeline import VideoSysPipeline
-from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
-from .videosys.core.pab_mgr import set_pab_manager
-
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
@@ -115,7 +111,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
-class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -144,10 +140,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
def __init__(
self,
vae: AutoencoderKLCogVideoX,
- transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
+ transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
original_mask = None,
- pab_config = None
):
super().__init__()
@@ -164,9 +159,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.video_processor.config.do_resize = False
- if pab_config is not None:
- set_pab_manager(pab_config)
-
self.input_with_padding = True
@@ -289,29 +281,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps.to(device), num_inference_steps - t_start
-
- def _gaussian_weights(self, t_tile_length, t_batch_size):
- from numpy import pi, exp, sqrt
-
- var = 0.01
- midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
- t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
- weights = torch.tensor(t_probs)
- weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
- return weights
-
- # def fuse_qkv_projections(self) -> None:
- # r"""Enables fused QKV projections."""
- # self.fusing_transformer = True
- # self.transformer.fuse_qkv_projections()
-
- # def unfuse_qkv_projections(self) -> None:
- # r"""Disable QKV projection fusion if enabled."""
- # if not self.fusing_transformer:
- # logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
- # else:
- # self.transformer.unfuse_qkv_projections()
- # self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
@@ -365,8 +334,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
height: int = 480,
width: int = 720,
num_frames: int = 48,
- t_tile_length: int = 12,
- t_tile_overlap: int = 4,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -487,9 +454,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
num_frames += self.additional_frames * self.vae_scale_factor_temporal
- #if latents is None and num_frames == t_tile_length:
- # num_frames += 1
-
if self.original_mask is not None:
image_latents = latents
original_image_latents = image_latents
@@ -569,23 +533,16 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7. context schedule and temporal tiling
- if context_schedule is not None and context_schedule == "temporal_tiling":
- t_tile_length = context_frames
- t_tile_overlap = context_overlap
- t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
- use_temporal_tiling = True
- logger.info("Temporal tiling enabled")
- elif context_schedule is not None:
+ if context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
- use_temporal_tiling = False
use_context_schedule = True
from .cogvideox_fun.context import get_context_scheduler
context = get_context_scheduler(context_schedule)
+ #todo ofs embeds?
else:
- use_temporal_tiling = False
use_context_schedule = False
logger.info("Temporal tiling and context schedule disabled")
# 7.5. Create rotary embeds if required
@@ -647,100 +604,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
for i, t in enumerate(timesteps):
if self.interrupt:
continue
- if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
- #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
- # =====================================================
- grid_ts = 0
- cur_t = 0
- while cur_t < latents.shape[1]:
- cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
- grid_ts += 1
-
- all_t = latents.shape[1]
- latents_all_list = []
- # =====================================================
-
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
- if self.transformer.config.use_rotary_positional_embeddings
- else None
- )
-
- latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
- latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
- latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
-
- #t_input = t[None].to(device)
- t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-
- # predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input_tile,
- encoder_hidden_states=prompt_embeds,
- timestep=t_input,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- )[0]
- noise_pred = noise_pred.float()
-
- if self.do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
- latents_all_list.append(latents_tile)
-
- # ==========================================
- latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
- # Add each tile contribution to overall latents
- for t_i in range(grid_ts):
- if t_i < grid_ts - 1:
- ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
- if t_i == grid_ts - 1:
- ofs_t = all_t - t_tile_length
-
- input_start_t = ofs_t
- input_end_t = ofs_t + t_tile_length
-
- latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
- contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
-
- latents_all /= contributors
-
- latents = latents_all
- #print("latents",latents.shape)
- # start diff diff
- if i < len(timesteps) - 1 and self.original_mask is not None:
- noise_timestep = timesteps[i + 1]
- image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
- )
- mask = mask.to(latents)
- ts_from = timesteps[0]
- ts_to = timesteps[-1]
- threshold = (t - ts_to) / (ts_from - ts_to)
- mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
- latents = image_latent * mask + latents * (1 - mask)
- # end diff diff
-
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- comfy_pbar.update(1)
- # ==========================================
- elif use_context_schedule:
+ # region context schedule sampling
+ if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
counter = torch.zeros_like(latent_model_input)
@@ -858,7 +723,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
-
+
+ # region sampling
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py
deleted file mode 100644
index 26550a2..0000000
--- a/videosys/cogvideox_transformer_3d.py
+++ /dev/null
@@ -1,621 +0,0 @@
-# Adapted from CogVideo
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# CogVideo: https://github.com/THUDM/CogVideo
-# diffusers: https://github.com/huggingface/diffusers
-# --------------------------------------------------------
-
-from typing import Any, Dict, Optional, Tuple, Union
-from einops import rearrange
-import torch
-import torch.nn.functional as F
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils import is_torch_version
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from torch import nn
-
-from .core.pab_mgr import enable_pab, if_broadcast_spatial
-from .modules.embeddings import apply_rotary_emb
-
-#from .modules.embeddings import CogVideoXPatchEmbed
-
-from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
-try:
- from sageattention import sageattn
- SAGEATTN_IS_AVAVILABLE = True
-except:
- SAGEATTN_IS_AVAVILABLE = False
-
-class CogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- @torch.compiler.disable()
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- attn_heads = attn.heads
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn_heads
-
- query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- emb_len = image_rotary_emb[0].shape[0]
- query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
- query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
- )
- if not attn.is_cross_attention:
- key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
- key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
- )
-
- if SAGEATTN_IS_AVAVILABLE:
- hidden_states = sageattn(query, key, value, is_causal=False)
- else:
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- @torch.compiler.disable()
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- if SAGEATTN_IS_AVAVILABLE:
- hidden_states = sageattn(query, key, value, is_causal=False)
- else:
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-@maybe_allow_in_graph
-class CogVideoXBlock(nn.Module):
- r"""
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
-
- Parameters:
- dim (`int`):
- The number of channels in the input and output.
- num_attention_heads (`int`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`):
- The number of channels in each head.
- time_embed_dim (`int`):
- The number of channels in timestep embedding.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to be used in feed-forward.
- attention_bias (`bool`, defaults to `False`):
- Whether or not to use bias in attention projection layers.
- qk_norm (`bool`, defaults to `True`):
- Whether or not to use normalization after query and key projections in Attention.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_eps (`float`, defaults to `1e-5`):
- Epsilon value for normalization layers.
- final_dropout (`bool` defaults to `False`):
- Whether to apply a final dropout after the last feed-forward layer.
- ff_inner_dim (`int`, *optional*, defaults to `None`):
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
- ff_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Feed-forward layer.
- attention_out_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Attention output projection layer.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- time_embed_dim: int,
- dropout: float = 0.0,
- activation_fn: str = "gelu-approximate",
- attention_bias: bool = False,
- qk_norm: bool = True,
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- final_dropout: bool = True,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- block_idx: int = 0,
- ):
- super().__init__()
-
- # 1. Self Attention
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.attn1 = Attention(
- query_dim=dim,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- qk_norm="layer_norm" if qk_norm else None,
- eps=1e-6,
- bias=attention_bias,
- out_bias=attention_out_bias,
- processor=CogVideoXAttnProcessor2_0(),
- )
-
- # parallel
- #self.attn1.parallel_manager = None
-
- # 2. Feed Forward
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
-
- # pab
- self.attn_count = 0
- self.last_attn = None
- self.block_idx = block_idx
- #@torch.compiler.disable()
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- temb: torch.Tensor,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- timestep=None,
- video_flow_feature: Optional[torch.Tensor] = None,
- fuser=None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
- hidden_states, encoder_hidden_states, temb
- )
- # Tora Motion-guidance Fuser
- if video_flow_feature is not None:
- H, W = video_flow_feature.shape[-2:]
- T = norm_hidden_states.shape[1] // H // W
- h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
- h = fuser(h, video_flow_feature.to(h), T=T)
- norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
- del h, fuser
- # attention
- if enable_pab():
- broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
- if enable_pab() and broadcast_attn:
- attn_hidden_states, attn_encoder_hidden_states = self.last_attn
- else:
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=norm_encoder_hidden_states,
- image_rotary_emb=image_rotary_emb,
- )
- if enable_pab():
- self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
-
- hidden_states = hidden_states + gate_msa * attn_hidden_states
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
- hidden_states, encoder_hidden_states, temb
- )
-
- # feed-forward
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- ff_output = self.ff(norm_hidden_states)
-
- hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
-
- return hidden_states, encoder_hidden_states
-
-
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
- """
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
-
- Parameters:
- num_attention_heads (`int`, defaults to `30`):
- The number of heads to use for multi-head attention.
- attention_head_dim (`int`, defaults to `64`):
- The number of channels in each head.
- in_channels (`int`, defaults to `16`):
- The number of channels in the input.
- out_channels (`int`, *optional*, defaults to `16`):
- The number of channels in the output.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- time_embed_dim (`int`, defaults to `512`):
- Output dimension of timestep embeddings.
- text_embed_dim (`int`, defaults to `4096`):
- Input dimension of text embeddings from the text encoder.
- num_layers (`int`, defaults to `30`):
- The number of layers of Transformer blocks to use.
- dropout (`float`, defaults to `0.0`):
- The dropout probability to use.
- attention_bias (`bool`, defaults to `True`):
- Whether or not to use bias in the attention projection layers.
- sample_width (`int`, defaults to `90`):
- The width of the input latents.
- sample_height (`int`, defaults to `60`):
- The height of the input latents.
- sample_frames (`int`, defaults to `49`):
- The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
- instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
- but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
- K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
- patch_size (`int`, defaults to `2`):
- The size of the patches to use in the patch embedding layer.
- temporal_compression_ratio (`int`, defaults to `4`):
- The compression ratio across the temporal dimension. See documentation for `sample_frames`.
- max_text_seq_length (`int`, defaults to `226`):
- The maximum sequence length of the input text embeddings.
- activation_fn (`str`, defaults to `"gelu-approximate"`):
- Activation function to use in feed-forward.
- timestep_activation_fn (`str`, defaults to `"silu"`):
- Activation function to use when generating the timestep embeddings.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether or not to use elementwise affine in normalization layers.
- norm_eps (`float`, defaults to `1e-5`):
- The epsilon value to use in normalization layers.
- spatial_interpolation_scale (`float`, defaults to `1.875`):
- Scaling factor to apply in 3D positional embeddings across spatial dimensions.
- temporal_interpolation_scale (`float`, defaults to `1.0`):
- Scaling factor to apply in 3D positional embeddings across temporal dimensions.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 30,
- attention_head_dim: int = 64,
- in_channels: int = 16,
- out_channels: Optional[int] = 16,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- time_embed_dim: int = 512,
- text_embed_dim: int = 4096,
- num_layers: int = 30,
- dropout: float = 0.0,
- attention_bias: bool = True,
- sample_width: int = 90,
- sample_height: int = 60,
- sample_frames: int = 49,
- patch_size: int = 2,
- temporal_compression_ratio: int = 4,
- max_text_seq_length: int = 226,
- activation_fn: str = "gelu-approximate",
- timestep_activation_fn: str = "silu",
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- spatial_interpolation_scale: float = 1.875,
- temporal_interpolation_scale: float = 1.0,
- use_rotary_positional_embeddings: bool = False,
- use_learned_positional_embeddings: bool = False,
- ):
- super().__init__()
- inner_dim = num_attention_heads * attention_head_dim
-
- post_patch_height = sample_height // patch_size
- post_patch_width = sample_width // patch_size
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
- # 1. Patch embedding
- self.patch_embed = CogVideoXPatchEmbed(
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- text_embed_dim=text_embed_dim,
- bias=True,
- sample_width=sample_width,
- sample_height=sample_height,
- sample_frames=sample_frames,
- temporal_compression_ratio=temporal_compression_ratio,
- max_text_seq_length=max_text_seq_length,
- spatial_interpolation_scale=spatial_interpolation_scale,
- temporal_interpolation_scale=temporal_interpolation_scale,
- use_positional_embeddings=not use_rotary_positional_embeddings,
- use_learned_positional_embeddings=use_learned_positional_embeddings,
- )
- self.embedding_dropout = nn.Dropout(dropout)
-
- # 2. 3D positional embeddings
- spatial_pos_embedding = get_3d_sincos_pos_embed(
- inner_dim,
- (post_patch_width, post_patch_height),
- post_time_compression_frames,
- spatial_interpolation_scale,
- temporal_interpolation_scale,
- )
- spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
- pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
- pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
- self.register_buffer("pos_embedding", pos_embedding, persistent=False)
-
- # 3. Time embeddings
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
-
- # 4. Define spatio-temporal transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- CogVideoXBlock(
- dim=inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- time_embed_dim=time_embed_dim,
- dropout=dropout,
- activation_fn=activation_fn,
- attention_bias=attention_bias,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- )
- for _ in range(num_layers)
- ]
- )
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
-
- # 5. Output blocks
- self.norm_out = AdaLayerNorm(
- embedding_dim=time_embed_dim,
- output_dim=2 * inner_dim,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- chunk_dim=1,
- )
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
- self.gradient_checkpointing = False
-
- self.fuser_list = None
-
- # parallel
- #self.parallel_manager = None
-
- # def enable_parallel(self, dp_size, sp_size, enable_cp):
- # # update cfg parallel
- # if enable_cp and sp_size % 2 == 0:
- # sp_size = sp_size // 2
- # cp_size = 2
- # else:
- # cp_size = 1
-
- # self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
-
- # for _, module in self.named_modules():
- # if hasattr(module, "parallel_manager"):
- # module.parallel_manager = self.parallel_manager
-
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- timestep: Union[int, float, torch.LongTensor],
- timestep_cond: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- return_dict: bool = True,
- controlnet_states: torch.Tensor = None,
- controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
- video_flow_features: Optional[torch.Tensor] = None,
- ):
- # if self.parallel_manager.cp_size > 1:
- # (
- # hidden_states,
- # encoder_hidden_states,
- # timestep,
- # timestep_cond,
- # image_rotary_emb,
- # ) = batch_func(
- # partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
- # hidden_states,
- # encoder_hidden_states,
- # timestep,
- # timestep_cond,
- # image_rotary_emb,
- # )
-
- batch_size, num_frames, channels, height, width = hidden_states.shape
-
- # 1. Time embedding
- timesteps = timestep
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=hidden_states.dtype)
- emb = self.time_embedding(t_emb, timestep_cond)
-
- # 2. Patch embedding
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-
- # 3. Position embedding
- text_seq_length = encoder_hidden_states.shape[1]
- if not self.config.use_rotary_positional_embeddings:
- seq_length = height * width * num_frames // (self.config.patch_size**2)
-
- pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
- hidden_states = hidden_states + pos_embeds
- hidden_states = self.embedding_dropout(hidden_states)
-
- encoder_hidden_states = hidden_states[:, :text_seq_length]
- hidden_states = hidden_states[:, text_seq_length:]
-
- # if self.parallel_manager.sp_size > 1:
- # set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
- # hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
-
- # 4. Transformer blocks
- for i, block in enumerate(self.transformer_blocks):
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=emb,
- image_rotary_emb=image_rotary_emb,
- timestep=timesteps if enable_pab() else None,
- video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
- fuser = self.fuser_list[i] if self.fuser_list is not None else None,
- )
- if (controlnet_states is not None) and (i < len(controlnet_states)):
- controlnet_states_block = controlnet_states[i]
- controlnet_block_weight = 1.0
- if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
- controlnet_block_weight = controlnet_weights[i]
- elif isinstance(controlnet_weights, (float, int)):
- controlnet_block_weight = controlnet_weights
-
- hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
-
- #if self.parallel_manager.sp_size > 1:
- # hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
-
- if not self.config.use_rotary_positional_embeddings:
- # CogVideoX-2B
- hidden_states = self.norm_final(hidden_states)
- else:
- # CogVideoX-5B
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states = self.norm_final(hidden_states)
- hidden_states = hidden_states[:, text_seq_length:]
-
- # 5. Final block
- hidden_states = self.norm_out(hidden_states, temb=emb)
- hidden_states = self.proj_out(hidden_states)
-
- # 6. Unpatchify
- p = self.config.patch_size
- output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-
- #if self.parallel_manager.cp_size > 1:
- # output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
-
- if not return_dict:
- return (output,)
- return Transformer2DModelOutput(sample=output)
diff --git a/videosys/core/__init__.py b/videosys/core/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py
deleted file mode 100644
index 6f19a50..0000000
--- a/videosys/core/pab_mgr.py
+++ /dev/null
@@ -1,232 +0,0 @@
-
-PAB_MANAGER = None
-
-
-class PABConfig:
- def __init__(
- self,
- steps: int,
- cross_broadcast: bool = False,
- cross_threshold: list = None,
- cross_range: int = None,
- spatial_broadcast: bool = False,
- spatial_threshold: list = None,
- spatial_range: int = None,
- temporal_broadcast: bool = False,
- temporal_threshold: list = None,
- temporal_range: int = None,
- mlp_broadcast: bool = False,
- mlp_spatial_broadcast_config: dict = None,
- mlp_temporal_broadcast_config: dict = None,
- ):
- self.steps = steps
-
- self.cross_broadcast = cross_broadcast
- self.cross_threshold = cross_threshold
- self.cross_range = cross_range
-
- self.spatial_broadcast = spatial_broadcast
- self.spatial_threshold = spatial_threshold
- self.spatial_range = spatial_range
-
- self.temporal_broadcast = temporal_broadcast
- self.temporal_threshold = temporal_threshold
- self.temporal_range = temporal_range
-
- self.mlp_broadcast = mlp_broadcast
- self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
- self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
- self.mlp_temporal_outputs = {}
- self.mlp_spatial_outputs = {}
-
-
-class PABManager:
- def __init__(self, config: PABConfig):
- self.config: PABConfig = config
-
- init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
- init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
- init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
- init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
- init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
- print(init_prompt)
-
- def if_broadcast_cross(self, timestep: int, count: int):
- if (
- self.config.cross_broadcast
- and (timestep is not None)
- and (count % self.config.cross_range != 0)
- and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
- ):
- flag = True
- else:
- flag = False
- count = (count + 1) % self.config.steps
- return flag, count
-
- def if_broadcast_temporal(self, timestep: int, count: int):
- if (
- self.config.temporal_broadcast
- and (timestep is not None)
- and (count % self.config.temporal_range != 0)
- and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
- ):
- flag = True
- else:
- flag = False
- count = (count + 1) % self.config.steps
- return flag, count
-
- def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
- if (
- self.config.spatial_broadcast
- and (timestep is not None)
- and (count % self.config.spatial_range != 0)
- and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
- ):
- flag = True
- else:
- flag = False
- count = (count + 1) % self.config.steps
- return flag, count
-
- @staticmethod
- def _is_t_in_skip_config(all_timesteps, timestep, config):
- is_t_in_skip_config = False
- skip_range = None
- for key in config:
- if key not in all_timesteps:
- continue
- index = all_timesteps.index(key)
- skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
- if timestep in skip_range:
- is_t_in_skip_config = True
- skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
- break
- return is_t_in_skip_config, skip_range
-
- def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
- if not self.config.mlp_broadcast:
- return False, None, False, None
-
- if is_temporal:
- cur_config = self.config.mlp_temporal_broadcast_config
- else:
- cur_config = self.config.mlp_spatial_broadcast_config
-
- is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
- next_flag = False
- if (
- self.config.mlp_broadcast
- and (timestep is not None)
- and (timestep in cur_config)
- and (block_idx in cur_config[timestep]["block"])
- ):
- flag = False
- next_flag = True
- count = count + 1
- elif (
- self.config.mlp_broadcast
- and (timestep is not None)
- and (is_t_in_skip_config)
- and (block_idx in cur_config[skip_range[0]]["block"])
- ):
- flag = True
- count = 0
- else:
- flag = False
-
- return flag, count, next_flag, skip_range
-
- def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
- if is_temporal:
- self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
- else:
- self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
-
- def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
- skip_start_t = skip_range[0]
- if is_temporal:
- skip_output = (
- self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
- if self.config.mlp_temporal_outputs is not None
- else None
- )
- else:
- skip_output = (
- self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
- if self.config.mlp_spatial_outputs is not None
- else None
- )
-
- if skip_output is not None:
- if timestep == skip_range[-1]:
- # TODO: save memory
- if is_temporal:
- del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
- else:
- del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
- else:
- raise ValueError(
- f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
- )
-
- return skip_output
-
- def get_spatial_mlp_outputs(self):
- return self.config.mlp_spatial_outputs
-
- def get_temporal_mlp_outputs(self):
- return self.config.mlp_temporal_outputs
-
-
-def set_pab_manager(config: PABConfig):
- global PAB_MANAGER
- PAB_MANAGER = PABManager(config)
-
-
-def enable_pab():
- if PAB_MANAGER is None:
- return False
- return (
- PAB_MANAGER.config.cross_broadcast
- or PAB_MANAGER.config.spatial_broadcast
- or PAB_MANAGER.config.temporal_broadcast
- )
-
-
-def update_steps(steps: int):
- if PAB_MANAGER is not None:
- PAB_MANAGER.config.steps = steps
-
-
-def if_broadcast_cross(timestep: int, count: int):
- if not enable_pab():
- return False, count
- return PAB_MANAGER.if_broadcast_cross(timestep, count)
-
-
-def if_broadcast_temporal(timestep: int, count: int):
- if not enable_pab():
- return False, count
- return PAB_MANAGER.if_broadcast_temporal(timestep, count)
-
-
-def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
- if not enable_pab():
- return False, count
- return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
-
-
-def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
- if not enable_pab():
- return False, count
- return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
-
-
-def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
- return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
-
-
-def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
- return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py
deleted file mode 100644
index 3244749..0000000
--- a/videosys/core/pipeline.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import inspect
-from abc import abstractmethod
-
-import torch
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-
-class VideoSysPipeline(DiffusionPipeline):
- def __init__(self):
- super().__init__()
-
- @staticmethod
- def set_eval_and_device(device: torch.device, *modules):
- for module in modules:
- module.eval()
- module.to(device)
-
- @abstractmethod
- def generate(self, *args, **kwargs):
- pass
-
- def __call__(self, *args, **kwargs):
- """
- In diffusers, it is a convention to call the pipeline object.
- But in VideoSys, we will use the generate method for better prompt.
- This is a wrapper for the generate method to support the diffusers usage.
- """
- return self.generate(*args, **kwargs)
-
- @classmethod
- def _get_signature_keys(cls, obj):
- parameters = inspect.signature(obj.__init__).parameters
- required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
- optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
- expected_modules = set(required_parameters.keys()) - {"self"}
- # modify: remove the config module from the expected modules
- expected_modules = expected_modules - {"config"}
-
- optional_names = list(optional_parameters)
- for name in optional_names:
- if name in cls._optional_components:
- expected_modules.add(name)
- optional_parameters.remove(name)
-
- return expected_modules, optional_parameters
diff --git a/videosys/modules/__init__.py b/videosys/modules/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/videosys/modules/activations.py b/videosys/modules/activations.py
deleted file mode 100644
index cf24149..0000000
--- a/videosys/modules/activations.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import torch.nn as nn
-
-approx_gelu = lambda: nn.GELU(approximate="tanh")
diff --git a/videosys/modules/downsampling.py b/videosys/modules/downsampling.py
deleted file mode 100644
index 9455a32..0000000
--- a/videosys/modules/downsampling.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class CogVideoXDownsample3D(nn.Module):
- # Todo: Wait for paper relase.
- r"""
- A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
-
- Args:
- in_channels (`int`):
- Number of channels in the input image.
- out_channels (`int`):
- Number of channels produced by the convolution.
- kernel_size (`int`, defaults to `3`):
- Size of the convolving kernel.
- stride (`int`, defaults to `2`):
- Stride of the convolution.
- padding (`int`, defaults to `0`):
- Padding added to all four sides of the input.
- compress_time (`bool`, defaults to `False`):
- Whether or not to compress the time dimension.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int = 3,
- stride: int = 2,
- padding: int = 0,
- compress_time: bool = False,
- ):
- super().__init__()
-
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
- self.compress_time = compress_time
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.compress_time:
- batch_size, channels, frames, height, width = x.shape
-
- # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
- x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
-
- if x.shape[-1] % 2 == 1:
- x_first, x_rest = x[..., 0], x[..., 1:]
- if x_rest.shape[-1] > 0:
- # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
- x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
-
- x = torch.cat([x_first[..., None], x_rest], dim=-1)
- # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
- else:
- # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
- x = F.avg_pool1d(x, kernel_size=2, stride=2)
- # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
-
- # Pad the tensor
- pad = (0, 1, 0, 1)
- x = F.pad(x, pad, mode="constant", value=0)
- batch_size, channels, frames, height, width = x.shape
- # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
- x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
- x = self.conv(x)
- # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
- x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
- return x
diff --git a/videosys/modules/embeddings.py b/videosys/modules/embeddings.py
deleted file mode 100644
index 04eba82..0000000
--- a/videosys/modules/embeddings.py
+++ /dev/null
@@ -1,308 +0,0 @@
-import math
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from einops import rearrange
-
-class CogVideoXPatchEmbed(nn.Module):
- def __init__(
- self,
- patch_size: int = 2,
- in_channels: int = 16,
- embed_dim: int = 1920,
- text_embed_dim: int = 4096,
- bias: bool = True,
- ) -> None:
- super().__init__()
- self.patch_size = patch_size
-
- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
- r"""
- Args:
- text_embeds (`torch.Tensor`):
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
- image_embeds (`torch.Tensor`):
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
- """
- text_embeds = self.text_proj(text_embeds)
-
- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
-
- embeds = torch.cat(
- [text_embeds, image_embeds], dim=1
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
- return embeds
-
-
-class OpenSoraPatchEmbed3D(nn.Module):
- """Video to Patch Embedding.
-
- Args:
- patch_size (int): Patch token size. Default: (2,4,4).
- in_chans (int): Number of input video channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self,
- patch_size=(2, 4, 4),
- in_chans=3,
- embed_dim=96,
- norm_layer=None,
- flatten=True,
- ):
- super().__init__()
- self.patch_size = patch_size
- self.flatten = flatten
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- """Forward function."""
- # padding
- _, _, D, H, W = x.size()
- if W % self.patch_size[2] != 0:
- x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
- if H % self.patch_size[1] != 0:
- x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
- if D % self.patch_size[0] != 0:
- x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
-
- x = self.proj(x) # (B C T H W)
- if self.norm is not None:
- D, Wh, Ww = x.size(2), x.size(3), x.size(4)
- x = x.flatten(2).transpose(1, 2)
- x = self.norm(x)
- x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
- return x
-
-
-class TimestepEmbedder(nn.Module):
- """
- Embeds scalar timesteps into vector representations.
- """
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10000):
- """
- Create sinusoidal timestep embeddings.
- :param t: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an (N, D) Tensor of positional embeddings.
- """
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
- half = dim // 2
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
- freqs = freqs.to(device=t.device)
- args = t[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t, dtype):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- if t_freq.dtype != dtype:
- t_freq = t_freq.to(dtype)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-
-class SizeEmbedder(TimestepEmbedder):
- """
- Embeds scalar timesteps into vector representations.
- """
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
- self.outdim = hidden_size
-
- def forward(self, s, bs):
- if s.ndim == 1:
- s = s[:, None]
- assert s.ndim == 2
- if s.shape[0] != bs:
- s = s.repeat(bs // s.shape[0], 1)
- assert s.shape[0] == bs
- b, dims = s.shape[0], s.shape[1]
- s = rearrange(s, "b d -> (b d)")
- s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
- s_emb = self.mlp(s_freq)
- s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
- return s_emb
-
- @property
- def dtype(self):
- return next(self.parameters()).dtype
-
-def get_3d_rotary_pos_embed(
- embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- RoPE for video tokens with 3D structure.
-
- Args:
- embed_dim: (`int`):
- The embedding dimension size, corresponding to hidden_size_head.
- crops_coords (`Tuple[int]`):
- The top-left and bottom-right coordinates of the crop.
- grid_size (`Tuple[int]`):
- The grid size of the spatial positional embedding (height, width).
- temporal_size (`int`):
- The size of the temporal dimension.
- theta (`float`):
- Scaling factor for frequency computation.
- use_real (`bool`):
- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-
- Returns:
- `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
- """
- start, stop = crops_coords
- grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
- grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
- grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-
- # Compute dimensions for each axis
- dim_t = embed_dim // 4
- dim_h = embed_dim // 8 * 3
- dim_w = embed_dim // 8 * 3
-
- # Temporal frequencies
- freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
- grid_t = torch.from_numpy(grid_t).float()
- freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
- freqs_t = freqs_t.repeat_interleave(2, dim=-1)
-
- # Spatial frequencies for height and width
- freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
- freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
- grid_h = torch.from_numpy(grid_h).float()
- grid_w = torch.from_numpy(grid_w).float()
- freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
- freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
- freqs_h = freqs_h.repeat_interleave(2, dim=-1)
- freqs_w = freqs_w.repeat_interleave(2, dim=-1)
-
- # Broadcast and concatenate tensors along specified dimension
- def broadcast(tensors, dim=-1):
- num_tensors = len(tensors)
- shape_lens = {len(t.shape) for t in tensors}
- assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
- shape_len = list(shape_lens)[0]
- dim = (dim + shape_len) if dim < 0 else dim
- dims = list(zip(*(list(t.shape) for t in tensors)))
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
- assert all(
- [*(len(set(t[1])) <= 2 for t in expandable_dims)]
- ), "invalid dimensions for broadcastable concatenation"
- max_dims = [(t[0], max(t[1])) for t in expandable_dims]
- expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
- expanded_dims.insert(dim, (dim, dims[dim]))
- expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
- tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
- return torch.cat(tensors, dim=dim)
-
- freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-
- t, h, w, d = freqs.shape
- freqs = freqs.view(t * h * w, d)
-
- # Generate sine and cosine components
- sin = freqs.sin()
- cos = freqs.cos()
-
- if use_real:
- return cos, sin
- else:
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
- return freqs_cis
-
-
-def apply_rotary_emb(
- x: torch.Tensor,
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
- use_real: bool = True,
- use_real_unbind_dim: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
- to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
- reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
- tensors contain rotary embeddings and are returned as real tensors.
-
- Args:
- x (`torch.Tensor`):
- Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
- freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
- """
- if use_real:
- cos, sin = freqs_cis # [S, D]
- cos = cos[None, None]
- sin = sin[None, None]
- cos, sin = cos.to(x.device), sin.to(x.device)
-
- if use_real_unbind_dim == -1:
- # Use for example in Lumina
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
- elif use_real_unbind_dim == -2:
- # Use for example in Stable Audio
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
- else:
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
-
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
-
- return out
- else:
- x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
- freqs_cis = freqs_cis.unsqueeze(2)
- x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
-
- return x_out.type_as(x)
diff --git a/videosys/modules/normalization.py b/videosys/modules/normalization.py
deleted file mode 100644
index 216d0cc..0000000
--- a/videosys/modules/normalization.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-
-
-class CogVideoXLayerNormZero(nn.Module):
- def __init__(
- self,
- conditioning_dim: int,
- embedding_dim: int,
- elementwise_affine: bool = True,
- eps: float = 1e-5,
- bias: bool = True,
- ) -> None:
- super().__init__()
-
- self.silu = nn.SiLU()
- self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
- self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
-
- def forward(
- self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
- hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
- encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
- return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
-
-
-class AdaLayerNorm(nn.Module):
- r"""
- Norm layer modified to incorporate timestep embeddings.
-
- Parameters:
- embedding_dim (`int`): The size of each embedding vector.
- num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
- output_dim (`int`, *optional*):
- norm_elementwise_affine (`bool`, defaults to `False):
- norm_eps (`bool`, defaults to `False`):
- chunk_dim (`int`, defaults to `0`):
- """
-
- def __init__(
- self,
- embedding_dim: int,
- num_embeddings: Optional[int] = None,
- output_dim: Optional[int] = None,
- norm_elementwise_affine: bool = False,
- norm_eps: float = 1e-5,
- chunk_dim: int = 0,
- ):
- super().__init__()
-
- self.chunk_dim = chunk_dim
- output_dim = output_dim or embedding_dim * 2
-
- if num_embeddings is not None:
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
- else:
- self.emb = None
-
- self.silu = nn.SiLU()
- self.linear = nn.Linear(embedding_dim, output_dim)
- self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
-
- def forward(
- self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
- ) -> torch.Tensor:
- if self.emb is not None:
- temb = self.emb(timestep)
-
- temb = self.linear(self.silu(temb))
-
- if self.chunk_dim == 1:
- # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
- # other if-branch. This branch is specific to CogVideoX for now.
- shift, scale = temb.chunk(2, dim=1)
- shift = shift[:, None, :]
- scale = scale[:, None, :]
- else:
- scale, shift = temb.chunk(2, dim=0)
-
- x = self.norm(x) * (1 + scale) + shift
- return x
diff --git a/videosys/modules/upsampling.py b/videosys/modules/upsampling.py
deleted file mode 100644
index f9a61b7..0000000
--- a/videosys/modules/upsampling.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class CogVideoXUpsample3D(nn.Module):
- r"""
- A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
-
- Args:
- in_channels (`int`):
- Number of channels in the input image.
- out_channels (`int`):
- Number of channels produced by the convolution.
- kernel_size (`int`, defaults to `3`):
- Size of the convolving kernel.
- stride (`int`, defaults to `1`):
- Stride of the convolution.
- padding (`int`, defaults to `1`):
- Padding added to all four sides of the input.
- compress_time (`bool`, defaults to `False`):
- Whether or not to compress the time dimension.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int = 3,
- stride: int = 1,
- padding: int = 1,
- compress_time: bool = False,
- ) -> None:
- super().__init__()
-
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
- self.compress_time = compress_time
-
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
- if self.compress_time:
- if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
- # split first frame
- x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
-
- x_first = F.interpolate(x_first, scale_factor=2.0)
- x_rest = F.interpolate(x_rest, scale_factor=2.0)
- x_first = x_first[:, :, None, :, :]
- inputs = torch.cat([x_first, x_rest], dim=2)
- elif inputs.shape[2] > 1:
- inputs = F.interpolate(inputs, scale_factor=2.0)
- else:
- inputs = inputs.squeeze(2)
- inputs = F.interpolate(inputs, scale_factor=2.0)
- inputs = inputs[:, :, None, :, :]
- else:
- # only interpolate 2D
- b, c, t, h, w = inputs.shape
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
- inputs = F.interpolate(inputs, scale_factor=2.0)
- inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
-
- b, c, t, h, w = inputs.shape
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
- inputs = self.conv(inputs)
- inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
-
- return inputs
diff --git a/videosys/pab.py b/videosys/pab.py
deleted file mode 100644
index 007e1b3..0000000
--- a/videosys/pab.py
+++ /dev/null
@@ -1,64 +0,0 @@
-class PABConfig:
- def __init__(
- self,
- steps: int,
- cross_broadcast: bool = False,
- cross_threshold: list = None,
- cross_range: int = None,
- spatial_broadcast: bool = False,
- spatial_threshold: list = None,
- spatial_range: int = None,
- temporal_broadcast: bool = False,
- temporal_threshold: list = None,
- temporal_range: int = None,
- mlp_broadcast: bool = False,
- mlp_spatial_broadcast_config: dict = None,
- mlp_temporal_broadcast_config: dict = None,
- ):
- self.steps = steps
-
- self.cross_broadcast = cross_broadcast
- self.cross_threshold = cross_threshold
- self.cross_range = cross_range
-
- self.spatial_broadcast = spatial_broadcast
- self.spatial_threshold = spatial_threshold
- self.spatial_range = spatial_range
-
- self.temporal_broadcast = temporal_broadcast
- self.temporal_threshold = temporal_threshold
- self.temporal_range = temporal_range
-
- self.mlp_broadcast = mlp_broadcast
- self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
- self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
- self.mlp_temporal_outputs = {}
- self.mlp_spatial_outputs = {}
-
-class CogVideoXPABConfig(PABConfig):
- def __init__(
- self,
- steps: int = 50,
- spatial_broadcast: bool = True,
- spatial_threshold: list = [100, 850],
- spatial_range: int = 2,
- temporal_broadcast: bool = False,
- temporal_threshold: list = [100, 850],
- temporal_range: int = 4,
- cross_broadcast: bool = False,
- cross_threshold: list = [100, 850],
- cross_range: int = 6,
- ):
- super().__init__(
- steps=steps,
- spatial_broadcast=spatial_broadcast,
- spatial_threshold=spatial_threshold,
- spatial_range=spatial_range,
- temporal_broadcast=temporal_broadcast,
- temporal_threshold=temporal_threshold,
- temporal_range=temporal_range,
- cross_broadcast=cross_broadcast,
- cross_threshold=cross_threshold,
- cross_range=cross_range
-
- )
\ No newline at end of file