From 0bd3da569ead6b36982dbeb55500a1c4963d137d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:54:52 +0200 Subject: [PATCH] code cleanup codebase getting too bloated: drop PAB support in favor of FasterCache drop temporal tiling in favor of FreeNoise --- cogvideox_fun/fun_pab_transformer_3d.py | 741 -------------------- cogvideox_fun/pipeline_cogvideox_control.py | 116 +-- cogvideox_fun/pipeline_cogvideox_inpaint.py | 126 +--- model_loading.py | 50 +- nodes.py | 64 +- pipeline_cogvideox.py | 154 +--- videosys/cogvideox_transformer_3d.py | 621 ---------------- videosys/core/__init__.py | 0 videosys/core/pab_mgr.py | 232 ------ videosys/core/pipeline.py | 44 -- videosys/modules/__init__.py | 0 videosys/modules/activations.py | 3 - videosys/modules/downsampling.py | 71 -- videosys/modules/embeddings.py | 308 -------- videosys/modules/normalization.py | 85 --- videosys/modules/upsampling.py | 67 -- videosys/pab.py | 64 -- 17 files changed, 35 insertions(+), 2711 deletions(-) delete mode 100644 cogvideox_fun/fun_pab_transformer_3d.py delete mode 100644 videosys/cogvideox_transformer_3d.py delete mode 100644 videosys/core/__init__.py delete mode 100644 videosys/core/pab_mgr.py delete mode 100644 videosys/core/pipeline.py delete mode 100644 videosys/modules/__init__.py delete mode 100644 videosys/modules/activations.py delete mode 100644 videosys/modules/downsampling.py delete mode 100644 videosys/modules/embeddings.py delete mode 100644 videosys/modules/normalization.py delete mode 100644 videosys/modules/upsampling.py delete mode 100644 videosys/pab.py diff --git a/cogvideox_fun/fun_pab_transformer_3d.py b/cogvideox_fun/fun_pab_transformer_3d.py deleted file mode 100644 index 25a3934..0000000 --- a/cogvideox_fun/fun_pab_transformer_3d.py +++ /dev/null @@ -1,741 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Dict, Optional, Tuple, Union - -import os -import json -import torch -import glob -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import is_torch_version, logging -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.attention import Attention, FeedForward -#from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.modeling_utils import ModelMixin -#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero - -from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero -from ..videosys.modules.embeddings import apply_rotary_emb -from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -try: - from sageattention import sageattn - SAGEATTN_IS_AVAVILABLE = True - logger.info("Using sageattn") -except: - logger.info("sageattn not found, using sdpa") - SAGEATTN_IS_AVAVILABLE = False - -class CogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - # if attn.parallel_manager.sp_size > 1: - # assert ( - # attn.heads % attn.parallel_manager.sp_size == 0 - # ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}" - # attn_heads = attn.heads // attn.parallel_manager.sp_size - # query, key, value = map( - # lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1), - # [query, key, value], - # ) - - attn_heads = attn.heads - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn_heads - - query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - emb_len = image_rotary_emb[0].shape[0] - query[:, :, text_seq_length : emb_len + 
text_seq_length] = apply_rotary_emb( - query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb - ) - if not attn.is_cross_attention: - key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb( - key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb - ) - - if SAGEATTN_IS_AVAVILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - - #if attn.parallel_manager.sp_size > 1: - # hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - - -class FusedCogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - - if SAGEATTN_IS_AVAVILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - 
encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - -class CogVideoXPatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, - ) -> None: - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - self.text_proj = nn.Linear(text_embed_dim, embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - r""" - Args: - text_embeds (`torch.Tensor`): - Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). - image_embeds (`torch.Tensor`): - Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). - """ - text_embeds = self.text_proj(text_embeds) - - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] - - embeds = torch.cat( - [text_embeds, image_embeds], dim=1 - ).contiguous() # [batch, seq_length + num_frames x height x width, channels] - return embeds - -@maybe_allow_in_graph -class CogVideoXBlock(nn.Module): - r""" - Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - time_embed_dim (`int`): - The number of channels in timestep embedding. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - qk_norm (`bool`, defaults to `True`): - Whether or not to use normalization after query and key projections in Attention. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - ff_inner_dim (`int`, *optional*, defaults to `None`): - Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in Feed-forward layer. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in Attention output projection layer. 
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - time_embed_dim: int, - dropout: float = 0.0, - activation_fn: str = "gelu-approximate", - attention_bias: bool = False, - qk_norm: bool = True, - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - final_dropout: bool = True, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - block_idx: int = 0, - ): - super().__init__() - - # 1. Self Attention - self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.attn1 = Attention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=attention_bias, - out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), - ) - - # parallel - #self.attn1.parallel_manager = None - - # 2. Feed Forward - self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - # pab - self.attn_count = 0 - self.last_attn = None - self.block_idx = block_idx - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - timestep=None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( - hidden_states, encoder_hidden_states, temb - ) - - # attention - if enable_pab(): - broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx) - if enable_pab() and broadcast_attn: - attn_hidden_states, attn_encoder_hidden_states = self.last_attn - else: - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - if enable_pab(): - self.last_attn = (attn_hidden_states, attn_encoder_hidden_states) - - hidden_states = hidden_states + gate_msa * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( - hidden_states, encoder_hidden_states, temb - ) - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] - - return hidden_states, encoder_hidden_states - - -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): - """ - A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). - - Parameters: - num_attention_heads (`int`, defaults to `30`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `64`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `16`): - The number of channels in the output. 
- flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - time_embed_dim (`int`, defaults to `512`): - Output dimension of timestep embeddings. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - num_layers (`int`, defaults to `30`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_width (`int`, defaults to `90`): - The width of the input latents. - sample_height (`int`, defaults to `60`): - The height of the input latents. - sample_frames (`int`, defaults to `49`): - The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 - instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, - but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with - K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - temporal_compression_ratio (`int`, defaults to `4`): - The compression ratio across the temporal dimension. See documentation for `sample_frames`. - max_text_seq_length (`int`, defaults to `226`): - The maximum sequence length of the input text embeddings. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - timestep_activation_fn (`str`, defaults to `"silu"`): - Activation function to use when generating the timestep embeddings. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): - The epsilon value to use in normalization layers. - spatial_interpolation_scale (`float`, defaults to `1.875`): - Scaling factor to apply in 3D positional embeddings across spatial dimensions. - temporal_interpolation_scale (`float`, defaults to `1.0`): - Scaling factor to apply in 3D positional embeddings across temporal dimensions. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 64, - in_channels: int = 16, - out_channels: Optional[int] = 16, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - time_embed_dim: int = 512, - text_embed_dim: int = 4096, - num_layers: int = 30, - dropout: float = 0.0, - attention_bias: bool = True, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 49, - patch_size: int = 2, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - activation_fn: str = "gelu-approximate", - timestep_activation_fn: str = "silu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_rotary_positional_embeddings: bool = False, - add_noise_in_inpaint_model: bool = False, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - self.post_patch_height = post_patch_height - self.post_patch_width = post_patch_width - self.post_time_compression_frames = post_time_compression_frames - self.patch_size = patch_size - - # 1. Patch embedding - self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) - self.embedding_dropout = nn.Dropout(dropout) - - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) - - # 3. Time embeddings - self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - - # 4. Define spatio-temporal transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - CogVideoXBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - dropout=dropout, - activation_fn=activation_fn, - attention_bias=attention_bias, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - - # 5. 
Output blocks - self.norm_out = AdaLayerNorm( - embedding_dim=time_embed_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - chunk_dim=1, - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - inpaint_latents: Optional[torch.Tensor] = None, - control_latents: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - return_dict: bool = True, - ): - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. Patch embedding - if inpaint_latents is not None: - hidden_states = torch.concat([hidden_states, inpaint_latents], 2) - if control_latents is not None: - hidden_states = torch.concat([hidden_states, control_latents], 2) - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - - # 3. 
Position embedding - text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] - pos_embeds = self.pos_embedding - emb_size = hidden_states.size()[-1] - pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size) - pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3]) - pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False) - pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size) - pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1) - pos_embeds = pos_embeds[:, : text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) - - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - # 4. Transformer blocks - - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - timestep=timestep, - ) - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 5. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 6. 
Unpatchify - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}): - if subfolder is not None: - pretrained_model_path = os.path.join(pretrained_model_path, subfolder) - print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") - - config_file = os.path.join(pretrained_model_path, 'config.json') - if not os.path.isfile(config_file): - raise RuntimeError(f"{config_file} does not exist") - with open(config_file, "r") as f: - config = json.load(f) - - from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **transformer_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) - model_file_safetensors = model_file.replace(".bin", ".safetensors") - if os.path.exists(model_file): - state_dict = torch.load(model_file, map_location="cpu") - elif os.path.exists(model_file_safetensors): - from safetensors.torch import load_file, safe_open - state_dict = load_file(model_file_safetensors) - else: - from safetensors.torch import load_file, safe_open - model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) - state_dict = {} - for model_file_safetensors in model_files_safetensors: - _state_dict = load_file(model_file_safetensors) - for key in _state_dict: - state_dict[key] = _state_dict[key] - - if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size(): - new_shape = model.state_dict()['patch_embed.proj.weight'].size() - if len(new_shape) == 5: - state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone() - state_dict['patch_embed.proj.weight'][:, :, :-1] = 0 - else: - if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]: - model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight'] - model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0 - state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] - else: - model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :] - state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] - - tmp_state_dict = {} - for key in state_dict: - if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): - tmp_state_dict[key] = state_dict[key] - else: - print(key, "Size don't match, skip") - state_dict = tmp_state_dict - - m, u = model.load_state_dict(state_dict, strict=False) - print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") - print(m) - - params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()] - print(f"### Mamba Parameters: {sum(params) / 1e6} M") - - params = [p.numel() if "attn1." 
in n else 0 for n, p in model.named_parameters()] - print(f"### attn1 Parameters: {sum(params) / 1e6} M") - - return model \ No newline at end of file diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py index 85687fe..f598147 100644 --- a/cogvideox_fun/pipeline_cogvideox_control.py +++ b/cogvideox_fun/pipeline_cogvideox_control.py @@ -33,10 +33,6 @@ from diffusers.video_processor import VideoProcessor from diffusers.image_processor import VaeImageProcessor from einops import rearrange -from ..videosys.core.pipeline import VideoSysPipeline -from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB -from ..videosys.core.pab_mgr import set_pab_manager - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -158,7 +154,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput): videos: torch.Tensor -class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): +class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline): r""" Pipeline for text-to-video generation using CogVideoX. @@ -188,7 +184,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - pab_config = None ): super().__init__() @@ -210,9 +205,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - if pab_config is not None: - set_pab_manager(pab_config) - def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None @@ -348,16 +340,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - - def _gaussian_weights(self, t_tile_length, t_batch_size): - from numpy import pi, exp, sqrt - - var = 0.01 - midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1 - t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)] - weights = torch.tensor(t_probs) - weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) - return weights # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( @@ -697,24 +679,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - # 8.5. 
Temporal tiling prep - if context_schedule is not None and context_schedule == "temporal_tiling": - t_tile_length = context_frames - t_tile_overlap = context_overlap - t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype) - use_temporal_tiling = True - print("Temporal tiling enabled") - elif context_schedule is not None: + if context_schedule is not None: print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") - use_temporal_tiling = False use_context_schedule = True from .context import get_context_scheduler context = get_context_scheduler(context_schedule) else: - use_temporal_tiling = False use_context_schedule = False - print("Temporal tiling and context schedule disabled") + print(" context schedule disabled") # 7. Create rotary embeds if required image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) @@ -735,88 +708,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): for i, t in enumerate(timesteps): if self.interrupt: continue - - if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler): - #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py - # ===================================================== - grid_ts = 0 - cur_t = 0 - while cur_t < latents.shape[1]: - cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length - grid_ts += 1 - - all_t = latents.shape[1] - latents_all_list = [] - # ===================================================== - - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, context_frames, device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - - for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - latents_tile = latents[:, input_start_t:input_end_t,:, :, :] - control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :] - - latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile - latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t) - - #t_input = t[None].to(device) - t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - - # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input_tile, - encoder_hidden_states=prompt_embeds, - timestep=t_input, - image_rotary_emb=image_rotary_emb, - return_dict=False, - control_latents=control_latents_tile, - )[0] - noise_pred = noise_pred.float() - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0] - latents_all_list.append(latents_tile) - - # ========================================== - latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype) - contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype) - # Add each tile 
contribution to overall latents - for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights - contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights - - latents_all /= contributors - - latents = latents_all - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - pbar.update(1) - # ========================================== - elif use_context_schedule: + if use_context_schedule: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py index 4b3b4f3..7b9d8e7 100644 --- a/cogvideox_fun/pipeline_cogvideox_inpaint.py +++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py @@ -33,11 +33,6 @@ from diffusers.video_processor import VideoProcessor from diffusers.image_processor import VaeImageProcessor from einops import rearrange -from ..videosys.core.pipeline import VideoSysPipeline -from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB -from ..videosys.core.pab_mgr import set_pab_manager - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -206,7 +201,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput): videos: torch.Tensor -class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): +class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline): r""" Pipeline for text-to-video generation using CogVideoX. @@ -236,7 +231,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - pab_config = None ): super().__init__() @@ -258,9 +252,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - if pab_config is not None: - set_pab_manager(pab_config) - def prepare_latents( self, batch_size, @@ -433,16 +424,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def _gaussian_weights(self, t_tile_length, t_batch_size): - from numpy import pi, exp, sqrt - - var = 0.01 - midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1 - t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)] - weights = torch.tensor(t_probs) - weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) - return weights - # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( self, @@ -866,22 +847,14 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Create rotary embeds if required - if context_schedule is not None and context_schedule == "temporal_tiling": - t_tile_length = context_frames - t_tile_overlap = context_overlap - t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype) - use_temporal_tiling = True - print("Temporal tiling enabled") - elif context_schedule is not None: + if context_schedule is not None: print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") - use_temporal_tiling = False use_context_schedule = True from .context import get_context_scheduler context = get_context_scheduler(context_schedule) else: - use_temporal_tiling = False use_context_schedule = False - print("Temporal tiling and context schedule disabled") + print("context schedule disabled") # 7. Create rotary embeds if required image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) @@ -915,87 +888,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): if self.interrupt: continue - if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler): - #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py - # ===================================================== - grid_ts = 0 - cur_t = 0 - while cur_t < latents.shape[1]: - cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length - grid_ts += 1 - - all_t = latents.shape[1] - latents_all_list = [] - # ===================================================== - - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - - for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - latents_tile = latents[:, input_start_t:input_end_t,:, :, :] - inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :] - - latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile - latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t) - - #t_input = t[None].to(device) - t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - - # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input_tile, - encoder_hidden_states=prompt_embeds, - timestep=t_input, - image_rotary_emb=image_rotary_emb, - return_dict=False, - inpaint_latents=inpaint_latents_tile, - )[0] - noise_pred = noise_pred.float() - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0] - latents_all_list.append(latents_tile) - - # ========================================== - latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype) - contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype) - # Add each tile contribution to overall latents 
- for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights - contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights - - latents_all /= contributors - - latents = latents_all - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - pbar.update(1) - # ========================================== - elif use_context_schedule: + if use_context_schedule: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1133,18 +1026,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): else: pbar.update(1) - # if output_type == "numpy": - # video = self.decode_latents(latents) - # elif not output_type == "latent": - # video = self.decode_latents(latents) - # video = self.video_processor.postprocess_video(video=video, output_type=output_type) - # else: - # video = latents - # Offload all models self.maybe_free_model_hooks() - # if not return_dict: - # video = torch.from_numpy(video) - return latents \ No newline at end of file diff --git a/model_loading.py b/model_loading.py index fe9245b..532d6aa 100644 --- a/model_loading.py +++ b/model_loading.py @@ -12,15 +12,12 @@ from .pipeline_cogvideox import CogVideoXPipeline from contextlib import nullcontext from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun -from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control -from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB - -from .utils import check_diffusers_version, remove_specific_blocks, log +from .utils import remove_specific_blocks, log from comfy.utils import load_torch_file script_directory = os.path.dirname(os.path.abspath(__file__)) @@ -95,7 +92,6 @@ class DownloadAndLoadCogVideoModel: "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}), "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), - "pab_config": ("PAB_CONFIG", {"default": None}), "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), "lora": ("COGLORA", {"default": None}), "compile_args":("COMPILEARGS", ), @@ -111,7 +107,7 @@ class DownloadAndLoadCogVideoModel: DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'" def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", - enable_sequential_cpu_offload=False, 
pab_config=None, block_edit=None, lora=None, compile_args=None, + enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None, attention_mode="sdpa", load_device="main_device"): if precision == "fp16" and "1.5" in model: @@ -188,15 +184,9 @@ class DownloadAndLoadCogVideoModel: # transformer if "Fun" in model: - if pab_config is not None: - transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder=subfolder) - else: - transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder) + transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder) else: - if pab_config is not None: - transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder=subfolder) - else: - transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) + transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) transformer = transformer.to(dtype).to(transformer_load_device) @@ -213,12 +203,12 @@ class DownloadAndLoadCogVideoModel: if "Fun" in model: vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) if "Pose" in model: - pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler) else: - pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) else: vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) - pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoXPipeline(vae, transformer, scheduler) if "cogvideox-2b-img2vid" in model: pipe.input_with_padding = False @@ -296,7 +286,7 @@ class DownloadAndLoadCogVideoModel: backend="nexfort", options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}}, ignores=["vae"], - fuse_qkv_projections=True if pab_config is None else False, + fuse_qkv_projections= False, ) pipeline = { @@ -334,7 +324,6 @@ class DownloadAndLoadCogVideoGGUFModel: "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), }, "optional": { - "pab_config": ("PAB_CONFIG", {"default": None}), "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), #"lora": ("COGLORA", {"default": None}), "compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), @@ -348,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel: CATEGORY = "CogVideoWrapper" def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, - pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"): + block_edit=None, compile="disabled", attention_mode="sdpa"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -396,10 +385,7 @@ class DownloadAndLoadCogVideoGGUFModel: transformer_config["in_channels"] = 32 else: transformer_config["in_channels"] = 33 - if pab_config is not None: - transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config) - else: - transformer = 
CogVideoXTransformer3DModelFun.from_config(transformer_config) + transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config) elif "I2V" in model or "Interpolation" in model: transformer_config["in_channels"] = 32 if "1_5" in model: @@ -409,16 +395,10 @@ class DownloadAndLoadCogVideoGGUFModel: transformer_config["patch_bias"] = False transformer_config["sample_height"] = 96 transformer_config["sample_width"] = 170 - if pab_config is not None: - transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config) - else: - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + transformer = CogVideoXTransformer3DModel.from_config(transformer_config) else: transformer_config["in_channels"] = 16 - if pab_config is not None: - transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config) - else: - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + transformer = CogVideoXTransformer3DModel.from_config(transformer_config) params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"} if "2b" in model: @@ -476,13 +456,13 @@ class DownloadAndLoadCogVideoGGUFModel: vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device) vae.load_state_dict(vae_sd) if "Pose" in model: - pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler) else: - pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) else: vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device) vae.load_state_dict(vae_sd) - pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) + pipe = CogVideoXPipeline(vae, transformer, scheduler) if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() diff --git a/nodes.py b/nodes.py index ecea9db..cd90f33 100644 --- a/nodes.py +++ b/nodes.py @@ -44,8 +44,6 @@ from PIL import Image import numpy as np import json - - script_directory = os.path.dirname(os.path.abspath(__file__)) if not "CogVideo" in folder_paths.folder_names_and_paths: @@ -53,61 +51,11 @@ if not "CogVideo" in folder_paths.folder_names_and_paths: if not "cogvideox_loras" in folder_paths.folder_names_and_paths: folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras")) -#PAB -from .videosys.pab import CogVideoXPABConfig - -class CogVideoPABConfig: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}), - "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ), - "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ), - "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ), - "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}), - "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ), - "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ), - "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, 
higher values are faster but quality may suffer"} ), - "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}), - "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ), - "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ), - "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ), - - "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ), - } - } - - RETURN_TYPES = ("PAB_CONFIG",) - RETURN_NAMES = ("pab_config", ) - FUNCTION = "config" - CATEGORY = "CogVideoWrapper" - DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use" - - def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range, - temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, - cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps): - - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - pab_config = CogVideoXPABConfig( - steps=steps, - spatial_broadcast=spatial_broadcast, - spatial_threshold=[spatial_threshold_end, spatial_threshold_start], - spatial_range=spatial_range, - temporal_broadcast=temporal_broadcast, - temporal_threshold=[temporal_threshold_end, temporal_threshold_start], - temporal_range=temporal_range, - cross_broadcast=cross_broadcast, - cross_threshold=[cross_threshold_end, cross_threshold_start], - cross_range=cross_range - ) - - return (pab_config, ) - class CogVideoContextOptions: @classmethod def INPUT_TYPES(s): return {"required": { - "context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],), + "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), "context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), "context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), @@ -1152,9 +1100,6 @@ class CogVideoXFunSampler: end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None # Load Sampler - if context_options is not None and context_options["context_schedule"] == "temporal_tiling": - log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM") - scheduler="CogVideoXDDIM" scheduler_config = pipeline["scheduler_config"] if scheduler in scheduler_mapping: noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) @@ -1282,7 +1227,7 @@ class CogVideoXFunControlSampler: CATEGORY = "CogVideoWrapper" def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents, - control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8, + control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, samples=None, denoise_strength=1.0, context_options=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -1306,9 
+1251,6 @@ class CogVideoXFunControlSampler: # Load Sampler scheduler_config = pipeline["scheduler_config"] - if context_options is not None and context_options["context_schedule"] == "temporal_tiling": - log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM") - scheduler="CogVideoXDDIM" if scheduler in scheduler_mapping: noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) pipe.scheduler = noise_scheduler @@ -1427,7 +1369,6 @@ NODE_CLASS_MAPPINGS = { "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler, "CogVideoXFunControlSampler": CogVideoXFunControlSampler, "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine, - "CogVideoPABConfig": CogVideoPABConfig, "CogVideoTransformerEdit": CogVideoTransformerEdit, "CogVideoControlImageEncode": CogVideoControlImageEncode, "CogVideoContextOptions": CogVideoContextOptions, @@ -1450,7 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler", "CogVideoXFunControlSampler": "CogVideoXFun Control Sampler", "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine", - "CogVideoPABConfig": "CogVideo PABConfig", "CogVideoTransformerEdit": "CogVideo TransformerEdit", "CogVideoControlImageEncode": "CogVideo Control ImageEncode", "CogVideoContextOptions": "CogVideo Context Options", diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 87d19e9..09e9103 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -20,8 +20,8 @@ import torch import torch.nn.functional as F import math -from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel -#from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.models import AutoencoderKLCogVideoX +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor @@ -35,10 +35,6 @@ from comfy.utils import ProgressBar logger = logging.get_logger(__name__) # pylint: disable=invalid-name -from .videosys.core.pipeline import VideoSysPipeline -from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB -from .videosys.core.pab_mgr import set_pab_manager - def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): tw = tgt_width th = tgt_height @@ -115,7 +111,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for text-to-video generation using CogVideoX. 
@@ -144,10 +140,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): def __init__( self, vae: AutoencoderKLCogVideoX, - transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB], + transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], original_mask = None, - pab_config = None ): super().__init__() @@ -164,9 +159,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.video_processor.config.do_resize = False - if pab_config is not None: - set_pab_manager(pab_config) - self.input_with_padding = True @@ -289,29 +281,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps.to(device), num_inference_steps - t_start - - def _gaussian_weights(self, t_tile_length, t_batch_size): - from numpy import pi, exp, sqrt - - var = 0.01 - midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1 - t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)] - weights = torch.tensor(t_probs) - weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) - return weights - - # def fuse_qkv_projections(self) -> None: - # r"""Enables fused QKV projections.""" - # self.fusing_transformer = True - # self.transformer.fuse_qkv_projections() - - # def unfuse_qkv_projections(self) -> None: - # r"""Disable QKV projection fusion if enabled.""" - # if not self.fusing_transformer: - # logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") - # else: - # self.transformer.unfuse_qkv_projections() - # self.fusing_transformer = False def _prepare_rotary_positional_embeddings( self, @@ -365,8 +334,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): height: int = 480, width: int = 720, num_frames: int = 48, - t_tile_length: int = 12, - t_tile_overlap: int = 4, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -487,9 +454,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): num_frames += self.additional_frames * self.vae_scale_factor_temporal - #if latents is None and num_frames == t_tile_length: - # num_frames += 1 - if self.original_mask is not None: image_latents = latents original_image_latents = image_latents @@ -569,23 +533,16 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 7. 
context schedule and temporal tiling - if context_schedule is not None and context_schedule == "temporal_tiling": - t_tile_length = context_frames - t_tile_overlap = context_overlap - t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype) - use_temporal_tiling = True - logger.info("Temporal tiling enabled") - elif context_schedule is not None: + if context_schedule is not None: if image_cond_latents is not None: raise NotImplementedError("Context schedule not currently supported with image conditioning") logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") - use_temporal_tiling = False use_context_schedule = True from .cogvideox_fun.context import get_context_scheduler context = get_context_scheduler(context_schedule) + #todo ofs embeds? else: - use_temporal_tiling = False use_context_schedule = False logger.info("Temporal tiling and context schedule disabled") # 7.5. Create rotary embeds if required @@ -647,100 +604,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): for i, t in enumerate(timesteps): if self.interrupt: continue - if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler): - #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py - # ===================================================== - grid_ts = 0 - cur_t = 0 - while cur_t < latents.shape[1]: - cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length - grid_ts += 1 - - all_t = latents.shape[1] - latents_all_list = [] - # ===================================================== - - for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device) - if self.transformer.config.use_rotary_positional_embeddings - else None - ) - - latents_tile = latents[:, input_start_t:input_end_t,:, :, :] - latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile - latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t) - - #t_input = t[None].to(device) - t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - - # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input_tile, - encoder_hidden_states=prompt_embeds, - timestep=t_input, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - noise_pred = noise_pred.float() - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0] - latents_all_list.append(latents_tile) - - # ========================================== - latents_all = torch.zeros(latents.shape, device=latents.device, 
dtype=self.vae.dtype) - contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype) - # Add each tile contribution to overall latents - for t_i in range(grid_ts): - if t_i < grid_ts - 1: - ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0) - if t_i == grid_ts - 1: - ofs_t = all_t - t_tile_length - - input_start_t = ofs_t - input_end_t = ofs_t + t_tile_length - - latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights - contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights - - latents_all /= contributors - - latents = latents_all - #print("latents",latents.shape) - # start diff diff - if i < len(timesteps) - 1 and self.original_mask is not None: - noise_timestep = timesteps[i + 1] - image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep]) - ) - mask = mask.to(latents) - ts_from = timesteps[0] - ts_to = timesteps[-1] - threshold = (t - ts_to) / (ts_from - ts_to) - mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask)) - latents = image_latent * mask + latents * (1 - mask) - # end diff diff - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - comfy_pbar.update(1) - # ========================================== - elif use_context_schedule: + # region context schedule sampling + if use_context_schedule: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) counter = torch.zeros_like(latent_model_input) @@ -858,7 +723,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() comfy_pbar.update(1) - + + # region sampling else: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py deleted file mode 100644 index 26550a2..0000000 --- a/videosys/cogvideox_transformer_3d.py +++ /dev/null @@ -1,621 +0,0 @@ -# Adapted from CogVideo - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
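Editor's note (illustrative, not part of the patch): the context-schedule path that replaces temporal tiling above denoises overlapping frame windows and averages them through a running counter (see the `counter = torch.zeros_like(latent_model_input)` hunk). A minimal sketch of that blending step, assuming hypothetical names and shapes:

import torch

def blend_window_predictions(window_preds, windows, latent_shape):
    # window_preds[i]: noise prediction for the latent frames listed in windows[i],
    # shape [B, len(windows[i]), C, H, W]; overlapping frames are simply averaged.
    noise_pred = torch.zeros(latent_shape)
    counter = torch.zeros(latent_shape)
    for pred, frames in zip(window_preds, windows):
        noise_pred[:, frames] += pred
        counter[:, frames] += 1
    return noise_pred / counter.clamp(min=1)

The pipeline performs this accumulation inside the denoising loop before the scheduler step; the helper above only illustrates the averaging idea.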
-# -------------------------------------------------------- -# References: -# CogVideo: https://github.com/THUDM/CogVideo -# diffusers: https://github.com/huggingface/diffusers -# -------------------------------------------------------- - -from typing import Any, Dict, Optional, Tuple, Union -from einops import rearrange -import torch -import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.attention import Attention, FeedForward -from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils import is_torch_version -from diffusers.utils.torch_utils import maybe_allow_in_graph -from torch import nn - -from .core.pab_mgr import enable_pab, if_broadcast_spatial -from .modules.embeddings import apply_rotary_emb - -#from .modules.embeddings import CogVideoXPatchEmbed - -from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero -try: - from sageattention import sageattn - SAGEATTN_IS_AVAVILABLE = True -except: - SAGEATTN_IS_AVAVILABLE = False - -class CogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - @torch.compiler.disable() - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - attn_heads = attn.heads - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn_heads - - query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - emb_len = image_rotary_emb[0].shape[0] - query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb( - query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb - ) - if not attn.is_cross_attention: - key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb( - key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb - ) - - if SAGEATTN_IS_AVAVILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - 
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - - -class FusedCogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - @torch.compiler.disable() - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - - if SAGEATTN_IS_AVAVILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - - -@maybe_allow_in_graph -class CogVideoXBlock(nn.Module): - r""" - Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. 
- time_embed_dim (`int`): - The number of channels in timestep embedding. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - qk_norm (`bool`, defaults to `True`): - Whether or not to use normalization after query and key projections in Attention. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - ff_inner_dim (`int`, *optional*, defaults to `None`): - Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in Feed-forward layer. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in Attention output projection layer. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - time_embed_dim: int, - dropout: float = 0.0, - activation_fn: str = "gelu-approximate", - attention_bias: bool = False, - qk_norm: bool = True, - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - final_dropout: bool = True, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - block_idx: int = 0, - ): - super().__init__() - - # 1. Self Attention - self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.attn1 = Attention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=attention_bias, - out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), - ) - - # parallel - #self.attn1.parallel_manager = None - - # 2. 
Feed Forward - self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - # pab - self.attn_count = 0 - self.last_attn = None - self.block_idx = block_idx - #@torch.compiler.disable() - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - timestep=None, - video_flow_feature: Optional[torch.Tensor] = None, - fuser=None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( - hidden_states, encoder_hidden_states, temb - ) - # Tora Motion-guidance Fuser - if video_flow_feature is not None: - H, W = video_flow_feature.shape[-2:] - T = norm_hidden_states.shape[1] // H // W - h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W) - h = fuser(h, video_flow_feature.to(h), T=T) - norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T) - del h, fuser - # attention - if enable_pab(): - broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx) - if enable_pab() and broadcast_attn: - attn_hidden_states, attn_encoder_hidden_states = self.last_attn - else: - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - if enable_pab(): - self.last_attn = (attn_hidden_states, attn_encoder_hidden_states) - - hidden_states = hidden_states + gate_msa * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( - hidden_states, encoder_hidden_states, temb - ) - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] - - return hidden_states, encoder_hidden_states - - -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): - """ - A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). - - Parameters: - num_attention_heads (`int`, defaults to `30`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `64`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `16`): - The number of channels in the output. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - time_embed_dim (`int`, defaults to `512`): - Output dimension of timestep embeddings. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - num_layers (`int`, defaults to `30`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. 
- attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_width (`int`, defaults to `90`): - The width of the input latents. - sample_height (`int`, defaults to `60`): - The height of the input latents. - sample_frames (`int`, defaults to `49`): - The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 - instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, - but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with - K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - temporal_compression_ratio (`int`, defaults to `4`): - The compression ratio across the temporal dimension. See documentation for `sample_frames`. - max_text_seq_length (`int`, defaults to `226`): - The maximum sequence length of the input text embeddings. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - timestep_activation_fn (`str`, defaults to `"silu"`): - Activation function to use when generating the timestep embeddings. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): - The epsilon value to use in normalization layers. - spatial_interpolation_scale (`float`, defaults to `1.875`): - Scaling factor to apply in 3D positional embeddings across spatial dimensions. - temporal_interpolation_scale (`float`, defaults to `1.0`): - Scaling factor to apply in 3D positional embeddings across temporal dimensions. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 64, - in_channels: int = 16, - out_channels: Optional[int] = 16, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - time_embed_dim: int = 512, - text_embed_dim: int = 4096, - num_layers: int = 30, - dropout: float = 0.0, - attention_bias: bool = True, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 49, - patch_size: int = 2, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - activation_fn: str = "gelu-approximate", - timestep_activation_fn: str = "silu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_rotary_positional_embeddings: bool = False, - use_learned_positional_embeddings: bool = False, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - - # 1. 
Patch embedding - self.patch_embed = CogVideoXPatchEmbed( - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - text_embed_dim=text_embed_dim, - bias=True, - sample_width=sample_width, - sample_height=sample_height, - sample_frames=sample_frames, - temporal_compression_ratio=temporal_compression_ratio, - max_text_seq_length=max_text_seq_length, - spatial_interpolation_scale=spatial_interpolation_scale, - temporal_interpolation_scale=temporal_interpolation_scale, - use_positional_embeddings=not use_rotary_positional_embeddings, - use_learned_positional_embeddings=use_learned_positional_embeddings, - ) - self.embedding_dropout = nn.Dropout(dropout) - - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) - - # 3. Time embeddings - self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - - # 4. Define spatio-temporal transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - CogVideoXBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - dropout=dropout, - activation_fn=activation_fn, - attention_bias=attention_bias, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - - # 5. 
Output blocks - self.norm_out = AdaLayerNorm( - embedding_dim=time_embed_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - chunk_dim=1, - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - - self.gradient_checkpointing = False - - self.fuser_list = None - - # parallel - #self.parallel_manager = None - - # def enable_parallel(self, dp_size, sp_size, enable_cp): - # # update cfg parallel - # if enable_cp and sp_size % 2 == 0: - # sp_size = sp_size // 2 - # cp_size = 2 - # else: - # cp_size = 1 - - # self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size) - - # for _, module in self.named_modules(): - # if hasattr(module, "parallel_manager"): - # module.parallel_manager = self.parallel_manager - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - return_dict: bool = True, - controlnet_states: torch.Tensor = None, - controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0, - video_flow_features: Optional[torch.Tensor] = None, - ): - # if self.parallel_manager.cp_size > 1: - # ( - # hidden_states, - # encoder_hidden_states, - # timestep, - # timestep_cond, - # image_rotary_emb, - # ) = batch_func( - # partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), - # hidden_states, - # encoder_hidden_states, - # timestep, - # timestep_cond, - # image_rotary_emb, - # ) - - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - - # 3. Position embedding - text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - - pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) - - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - # if self.parallel_manager.sp_size > 1: - # set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group) - # hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) - - # 4. 
Transformer blocks - for i, block in enumerate(self.transformer_blocks): - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - timestep=timesteps if enable_pab() else None, - video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, - fuser = self.fuser_list[i] if self.fuser_list is not None else None, - ) - if (controlnet_states is not None) and (i < len(controlnet_states)): - controlnet_states_block = controlnet_states[i] - controlnet_block_weight = 1.0 - if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights): - controlnet_block_weight = controlnet_weights[i] - elif isinstance(controlnet_weights, (float, int)): - controlnet_block_weight = controlnet_weights - - hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight - - #if self.parallel_manager.sp_size > 1: - # hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 5. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 6. Unpatchify - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - #if self.parallel_manager.cp_size > 1: - # output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) diff --git a/videosys/core/__init__.py b/videosys/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py deleted file mode 100644 index 6f19a50..0000000 --- a/videosys/core/pab_mgr.py +++ /dev/null @@ -1,232 +0,0 @@ - -PAB_MANAGER = None - - -class PABConfig: - def __init__( - self, - steps: int, - cross_broadcast: bool = False, - cross_threshold: list = None, - cross_range: int = None, - spatial_broadcast: bool = False, - spatial_threshold: list = None, - spatial_range: int = None, - temporal_broadcast: bool = False, - temporal_threshold: list = None, - temporal_range: int = None, - mlp_broadcast: bool = False, - mlp_spatial_broadcast_config: dict = None, - mlp_temporal_broadcast_config: dict = None, - ): - self.steps = steps - - self.cross_broadcast = cross_broadcast - self.cross_threshold = cross_threshold - self.cross_range = cross_range - - self.spatial_broadcast = spatial_broadcast - self.spatial_threshold = spatial_threshold - self.spatial_range = spatial_range - - self.temporal_broadcast = temporal_broadcast - self.temporal_threshold = temporal_threshold - self.temporal_range = temporal_range - - self.mlp_broadcast = mlp_broadcast - self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config - self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config - self.mlp_temporal_outputs = {} - self.mlp_spatial_outputs = {} - - -class PABManager: - def __init__(self, config: PABConfig): - self.config: PABConfig = config - - init_prompt = f"Init Pyramid Attention Broadcast. 
steps: {config.steps}." - init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}." - init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}." - init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}." - init_prompt += f" mlp broadcast: {config.mlp_broadcast}." - print(init_prompt) - - def if_broadcast_cross(self, timestep: int, count: int): - if ( - self.config.cross_broadcast - and (timestep is not None) - and (count % self.config.cross_range != 0) - and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1]) - ): - flag = True - else: - flag = False - count = (count + 1) % self.config.steps - return flag, count - - def if_broadcast_temporal(self, timestep: int, count: int): - if ( - self.config.temporal_broadcast - and (timestep is not None) - and (count % self.config.temporal_range != 0) - and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1]) - ): - flag = True - else: - flag = False - count = (count + 1) % self.config.steps - return flag, count - - def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int): - if ( - self.config.spatial_broadcast - and (timestep is not None) - and (count % self.config.spatial_range != 0) - and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1]) - ): - flag = True - else: - flag = False - count = (count + 1) % self.config.steps - return flag, count - - @staticmethod - def _is_t_in_skip_config(all_timesteps, timestep, config): - is_t_in_skip_config = False - skip_range = None - for key in config: - if key not in all_timesteps: - continue - index = all_timesteps.index(key) - skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])] - if timestep in skip_range: - is_t_in_skip_config = True - skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]] - break - return is_t_in_skip_config, skip_range - - def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): - if not self.config.mlp_broadcast: - return False, None, False, None - - if is_temporal: - cur_config = self.config.mlp_temporal_broadcast_config - else: - cur_config = self.config.mlp_spatial_broadcast_config - - is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config) - next_flag = False - if ( - self.config.mlp_broadcast - and (timestep is not None) - and (timestep in cur_config) - and (block_idx in cur_config[timestep]["block"]) - ): - flag = False - next_flag = True - count = count + 1 - elif ( - self.config.mlp_broadcast - and (timestep is not None) - and (is_t_in_skip_config) - and (block_idx in cur_config[skip_range[0]]["block"]) - ): - flag = True - count = 0 - else: - flag = False - - return flag, count, next_flag, skip_range - - def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False): - if is_temporal: - self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output - else: - self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output - - def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False): - skip_start_t = skip_range[0] - if is_temporal: - skip_output = ( - self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None) - if 
self.config.mlp_temporal_outputs is not None - else None - ) - else: - skip_output = ( - self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None) - if self.config.mlp_spatial_outputs is not None - else None - ) - - if skip_output is not None: - if timestep == skip_range[-1]: - # TODO: save memory - if is_temporal: - del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)] - else: - del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)] - else: - raise ValueError( - f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}" - ) - - return skip_output - - def get_spatial_mlp_outputs(self): - return self.config.mlp_spatial_outputs - - def get_temporal_mlp_outputs(self): - return self.config.mlp_temporal_outputs - - -def set_pab_manager(config: PABConfig): - global PAB_MANAGER - PAB_MANAGER = PABManager(config) - - -def enable_pab(): - if PAB_MANAGER is None: - return False - return ( - PAB_MANAGER.config.cross_broadcast - or PAB_MANAGER.config.spatial_broadcast - or PAB_MANAGER.config.temporal_broadcast - ) - - -def update_steps(steps: int): - if PAB_MANAGER is not None: - PAB_MANAGER.config.steps = steps - - -def if_broadcast_cross(timestep: int, count: int): - if not enable_pab(): - return False, count - return PAB_MANAGER.if_broadcast_cross(timestep, count) - - -def if_broadcast_temporal(timestep: int, count: int): - if not enable_pab(): - return False, count - return PAB_MANAGER.if_broadcast_temporal(timestep, count) - - -def if_broadcast_spatial(timestep: int, count: int, block_idx: int): - if not enable_pab(): - return False, count - return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx) - - -def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): - if not enable_pab(): - return False, count - return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal) - - -def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False): - return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal) - - -def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False): - return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal) diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py deleted file mode 100644 index 3244749..0000000 --- a/videosys/core/pipeline.py +++ /dev/null @@ -1,44 +0,0 @@ -import inspect -from abc import abstractmethod - -import torch -from diffusers.pipelines.pipeline_utils import DiffusionPipeline - -class VideoSysPipeline(DiffusionPipeline): - def __init__(self): - super().__init__() - - @staticmethod - def set_eval_and_device(device: torch.device, *modules): - for module in modules: - module.eval() - module.to(device) - - @abstractmethod - def generate(self, *args, **kwargs): - pass - - def __call__(self, *args, **kwargs): - """ - In diffusers, it is a convention to call the pipeline object. - But in VideoSys, we will use the generate method for better prompt. - This is a wrapper for the generate method to support the diffusers usage. 
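Editor's note: for readers unfamiliar with the PAB manager removed above, its broadcast test boils down to "reuse the cached attention output while the timestep sits inside a window, except on every Nth call, which recomputes". A simplified, self-contained sketch (the real manager additionally wraps the counter modulo the step count):

def should_broadcast(timestep, count, threshold, broadcast_range, enabled=True):
    # True -> skip recomputing attention this step and reuse the cached output
    # from the last refresh; every `broadcast_range`-th call recomputes.
    reuse = (
        enabled
        and timestep is not None
        and count % broadcast_range != 0
        and threshold[0] < timestep < threshold[1]
    )
    return reuse, count + 1

# e.g. should_broadcast(500, 3, threshold=[100, 850], broadcast_range=2) -> (True, 4)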
- """ - return self.generate(*args, **kwargs) - - @classmethod - def _get_signature_keys(cls, obj): - parameters = inspect.signature(obj.__init__).parameters - required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} - optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - {"self"} - # modify: remove the config module from the expected modules - expected_modules = expected_modules - {"config"} - - optional_names = list(optional_parameters) - for name in optional_names: - if name in cls._optional_components: - expected_modules.add(name) - optional_parameters.remove(name) - - return expected_modules, optional_parameters diff --git a/videosys/modules/__init__.py b/videosys/modules/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/videosys/modules/activations.py b/videosys/modules/activations.py deleted file mode 100644 index cf24149..0000000 --- a/videosys/modules/activations.py +++ /dev/null @@ -1,3 +0,0 @@ -import torch.nn as nn - -approx_gelu = lambda: nn.GELU(approximate="tanh") diff --git a/videosys/modules/downsampling.py b/videosys/modules/downsampling.py deleted file mode 100644 index 9455a32..0000000 --- a/videosys/modules/downsampling.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class CogVideoXDownsample3D(nn.Module): - # Todo: Wait for paper relase. - r""" - A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI - - Args: - in_channels (`int`): - Number of channels in the input image. - out_channels (`int`): - Number of channels produced by the convolution. - kernel_size (`int`, defaults to `3`): - Size of the convolving kernel. - stride (`int`, defaults to `2`): - Stride of the convolution. - padding (`int`, defaults to `0`): - Padding added to all four sides of the input. - compress_time (`bool`, defaults to `False`): - Whether or not to compress the time dimension. 
- """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 2, - padding: int = 0, - compress_time: bool = False, - ): - super().__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.compress_time = compress_time - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.compress_time: - batch_size, channels, frames, height, width = x.shape - - # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) - x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) - - if x.shape[-1] % 2 == 1: - x_first, x_rest = x[..., 0], x[..., 1:] - if x_rest.shape[-1] > 0: - # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) - x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) - - x = torch.cat([x_first[..., None], x_rest], dim=-1) - # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) - x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) - else: - # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) - x = F.avg_pool1d(x, kernel_size=2, stride=2) - # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) - x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) - - # Pad the tensor - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - batch_size, channels, frames, height, width = x.shape - # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) - x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) - x = self.conv(x) - # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) - x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - return x diff --git a/videosys/modules/embeddings.py b/videosys/modules/embeddings.py deleted file mode 100644 index 04eba82..0000000 --- a/videosys/modules/embeddings.py +++ /dev/null @@ -1,308 +0,0 @@ -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from einops import rearrange - -class CogVideoXPatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, - ) -> None: - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - self.text_proj = nn.Linear(text_embed_dim, embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - r""" - Args: - text_embeds (`torch.Tensor`): - Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). - image_embeds (`torch.Tensor`): - Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). 
- """ - text_embeds = self.text_proj(text_embeds) - - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] - - embeds = torch.cat( - [text_embeds, image_embeds], dim=1 - ).contiguous() # [batch, seq_length + num_frames x height x width, channels] - return embeds - - -class OpenSoraPatchEmbed3D(nn.Module): - """Video to Patch Embedding. - - Args: - patch_size (int): Patch token size. Default: (2,4,4). - in_chans (int): Number of input video channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__( - self, - patch_size=(2, 4, 4), - in_chans=3, - embed_dim=96, - norm_layer=None, - flatten=True, - ): - super().__init__() - self.patch_size = patch_size - self.flatten = flatten - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, D, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if D % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) - - x = self.proj(x) # (B C T H W) - if self.norm is not None: - D, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC - return x - - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. 
- """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) - freqs = freqs.to(device=t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t, dtype): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - if t_freq.dtype != dtype: - t_freq = t_freq.to(dtype) - t_emb = self.mlp(t_freq) - return t_emb - - -class SizeEmbedder(TimestepEmbedder): - """ - Embeds scalar timesteps into vector representations. - """ - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - self.outdim = hidden_size - - def forward(self, s, bs): - if s.ndim == 1: - s = s[:, None] - assert s.ndim == 2 - if s.shape[0] != bs: - s = s.repeat(bs // s.shape[0], 1) - assert s.shape[0] == bs - b, dims = s.shape[0], s.shape[1] - s = rearrange(s, "b d -> (b d)") - s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) - s_emb = self.mlp(s_freq) - s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) - return s_emb - - @property - def dtype(self): - return next(self.parameters()).dtype - -def get_3d_rotary_pos_embed( - embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - RoPE for video tokens with 3D structure. - - Args: - embed_dim: (`int`): - The embedding dimension size, corresponding to hidden_size_head. - crops_coords (`Tuple[int]`): - The top-left and bottom-right coordinates of the crop. - grid_size (`Tuple[int]`): - The grid size of the spatial positional embedding (height, width). - temporal_size (`int`): - The size of the temporal dimension. - theta (`float`): - Scaling factor for frequency computation. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - - Returns: - `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. 
- """ - start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) - grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) - - # Compute dimensions for each axis - dim_t = embed_dim // 4 - dim_h = embed_dim // 8 * 3 - dim_w = embed_dim // 8 * 3 - - # Temporal frequencies - freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) - grid_t = torch.from_numpy(grid_t).float() - freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) - freqs_t = freqs_t.repeat_interleave(2, dim=-1) - - # Spatial frequencies for height and width - freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) - freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) - grid_h = torch.from_numpy(grid_h).float() - grid_w = torch.from_numpy(grid_w).float() - freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) - freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) - freqs_h = freqs_h.repeat_interleave(2, dim=-1) - freqs_w = freqs_w.repeat_interleave(2, dim=-1) - - # Broadcast and concatenate tensors along specified dimension - def broadcast(tensors, dim=-1): - num_tensors = len(tensors) - shape_lens = {len(t.shape) for t in tensors} - assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" - shape_len = list(shape_lens)[0] - dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*(list(t.shape) for t in tensors))) - expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] - assert all( - [*(len(set(t[1])) <= 2 for t in expandable_dims)] - ), "invalid dimensions for broadcastable concatenation" - max_dims = [(t[0], max(t[1])) for t in expandable_dims] - expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] - expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) - tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] - return torch.cat(tensors, dim=dim) - - freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - - t, h, w, d = freqs.shape - freqs = freqs.view(t * h * w, d) - - # Generate sine and cosine components - sin = freqs.sin() - cos = freqs.cos() - - if use_real: - return cos, sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - return freqs_cis - - -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - use_real: bool = True, - use_real_unbind_dim: int = -1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings - to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are - reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting - tensors contain rotary embeddings and are returned as real tensors. - - Args: - x (`torch.Tensor`): - Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply - freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
- """ - if use_real: - cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - if use_real_unbind_dim == -1: - # Use for example in Lumina - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - elif use_real_unbind_dim == -2: - # Use for example in Stable Audio - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] - x_rotated = torch.cat([-x_imag, x_real], dim=-1) - else: - raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - - return out - else: - x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) - - return x_out.type_as(x) diff --git a/videosys/modules/normalization.py b/videosys/modules/normalization.py deleted file mode 100644 index 216d0cc..0000000 --- a/videosys/modules/normalization.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn - - -class CogVideoXLayerNormZero(nn.Module): - def __init__( - self, - conditioning_dim: int, - embedding_dim: int, - elementwise_affine: bool = True, - eps: float = 1e-5, - bias: bool = True, - ) -> None: - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) - self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward( - self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) - hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] - encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] - return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] - - -class AdaLayerNorm(nn.Module): - r""" - Norm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`, *optional*): The size of the embeddings dictionary. 
- output_dim (`int`, *optional*): - norm_elementwise_affine (`bool`, defaults to `False): - norm_eps (`bool`, defaults to `False`): - chunk_dim (`int`, defaults to `0`): - """ - - def __init__( - self, - embedding_dim: int, - num_embeddings: Optional[int] = None, - output_dim: Optional[int] = None, - norm_elementwise_affine: bool = False, - norm_eps: float = 1e-5, - chunk_dim: int = 0, - ): - super().__init__() - - self.chunk_dim = chunk_dim - output_dim = output_dim or embedding_dim * 2 - - if num_embeddings is not None: - self.emb = nn.Embedding(num_embeddings, embedding_dim) - else: - self.emb = None - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, output_dim) - self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - - def forward( - self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None - ) -> torch.Tensor: - if self.emb is not None: - temb = self.emb(timestep) - - temb = self.linear(self.silu(temb)) - - if self.chunk_dim == 1: - # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the - # other if-branch. This branch is specific to CogVideoX for now. - shift, scale = temb.chunk(2, dim=1) - shift = shift[:, None, :] - scale = scale[:, None, :] - else: - scale, shift = temb.chunk(2, dim=0) - - x = self.norm(x) * (1 + scale) + shift - return x diff --git a/videosys/modules/upsampling.py b/videosys/modules/upsampling.py deleted file mode 100644 index f9a61b7..0000000 --- a/videosys/modules/upsampling.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class CogVideoXUpsample3D(nn.Module): - r""" - A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. - - Args: - in_channels (`int`): - Number of channels in the input image. - out_channels (`int`): - Number of channels produced by the convolution. - kernel_size (`int`, defaults to `3`): - Size of the convolving kernel. - stride (`int`, defaults to `1`): - Stride of the convolution. - padding (`int`, defaults to `1`): - Padding added to all four sides of the input. - compress_time (`bool`, defaults to `False`): - Whether or not to compress the time dimension. 
- """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - padding: int = 1, - compress_time: bool = False, - ) -> None: - super().__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.compress_time = compress_time - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - if self.compress_time: - if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: - # split first frame - x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] - - x_first = F.interpolate(x_first, scale_factor=2.0) - x_rest = F.interpolate(x_rest, scale_factor=2.0) - x_first = x_first[:, :, None, :, :] - inputs = torch.cat([x_first, x_rest], dim=2) - elif inputs.shape[2] > 1: - inputs = F.interpolate(inputs, scale_factor=2.0) - else: - inputs = inputs.squeeze(2) - inputs = F.interpolate(inputs, scale_factor=2.0) - inputs = inputs[:, :, None, :, :] - else: - # only interpolate 2D - b, c, t, h, w = inputs.shape - inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - inputs = F.interpolate(inputs, scale_factor=2.0) - inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) - - b, c, t, h, w = inputs.shape - inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - inputs = self.conv(inputs) - inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) - - return inputs diff --git a/videosys/pab.py b/videosys/pab.py deleted file mode 100644 index 007e1b3..0000000 --- a/videosys/pab.py +++ /dev/null @@ -1,64 +0,0 @@ -class PABConfig: - def __init__( - self, - steps: int, - cross_broadcast: bool = False, - cross_threshold: list = None, - cross_range: int = None, - spatial_broadcast: bool = False, - spatial_threshold: list = None, - spatial_range: int = None, - temporal_broadcast: bool = False, - temporal_threshold: list = None, - temporal_range: int = None, - mlp_broadcast: bool = False, - mlp_spatial_broadcast_config: dict = None, - mlp_temporal_broadcast_config: dict = None, - ): - self.steps = steps - - self.cross_broadcast = cross_broadcast - self.cross_threshold = cross_threshold - self.cross_range = cross_range - - self.spatial_broadcast = spatial_broadcast - self.spatial_threshold = spatial_threshold - self.spatial_range = spatial_range - - self.temporal_broadcast = temporal_broadcast - self.temporal_threshold = temporal_threshold - self.temporal_range = temporal_range - - self.mlp_broadcast = mlp_broadcast - self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config - self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config - self.mlp_temporal_outputs = {} - self.mlp_spatial_outputs = {} - -class CogVideoXPABConfig(PABConfig): - def __init__( - self, - steps: int = 50, - spatial_broadcast: bool = True, - spatial_threshold: list = [100, 850], - spatial_range: int = 2, - temporal_broadcast: bool = False, - temporal_threshold: list = [100, 850], - temporal_range: int = 4, - cross_broadcast: bool = False, - cross_threshold: list = [100, 850], - cross_range: int = 6, - ): - super().__init__( - steps=steps, - spatial_broadcast=spatial_broadcast, - spatial_threshold=spatial_threshold, - spatial_range=spatial_range, - temporal_broadcast=temporal_broadcast, - temporal_threshold=temporal_threshold, - temporal_range=temporal_range, - cross_broadcast=cross_broadcast, - cross_threshold=cross_threshold, - cross_range=cross_range - - ) \ No newline at end of file