From 93edf24631d35b1f8222462b3f2da57c241f479c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 22 Sep 2024 20:47:32 +0300 Subject: [PATCH] initial experimental PAB support (only normal text2vid for now) --- nodes.py | 102 ++++- pipeline_cogvideox.py | 15 +- videosys/cogvideox_transformer_3d.py | 604 +++++++++++++++++++++++++++ videosys/core/__init__.py | 0 videosys/core/comm.py | 406 ++++++++++++++++++ videosys/core/pab_mgr.py | 232 ++++++++++ videosys/core/pipeline.py | 52 +++ videosys/modules/__init__.py | 0 videosys/modules/activations.py | 3 + videosys/modules/attentions.py | 205 +++++++++ videosys/modules/downsampling.py | 71 ++++ videosys/modules/embeddings.py | 412 ++++++++++++++++++ videosys/modules/normalization.py | 102 +++++ videosys/modules/upsampling.py | 67 +++ videosys/utils/logging.py | 32 ++ videosys/utils/utils.py | 92 ++++ 16 files changed, 2384 insertions(+), 11 deletions(-) create mode 100644 videosys/cogvideox_transformer_3d.py create mode 100644 videosys/core/__init__.py create mode 100644 videosys/core/comm.py create mode 100644 videosys/core/pab_mgr.py create mode 100644 videosys/core/pipeline.py create mode 100644 videosys/modules/__init__.py create mode 100644 videosys/modules/activations.py create mode 100644 videosys/modules/attentions.py create mode 100644 videosys/modules/downsampling.py create mode 100644 videosys/modules/embeddings.py create mode 100644 videosys/modules/normalization.py create mode 100644 videosys/modules/upsampling.py create mode 100644 videosys/utils/logging.py create mode 100644 videosys/utils/utils.py diff --git a/nodes.py b/nodes.py index 59d5621..944db11 100644 --- a/nodes.py +++ b/nodes.py @@ -15,7 +15,6 @@ from diffusers.schedulers import ( HeunDiscreteScheduler, SASolverScheduler, DEISMultistepScheduler, - DDIMInverseScheduler ) scheduler_mapping = { @@ -50,6 +49,88 @@ log = logging.getLogger(__name__) script_directory = os.path.dirname(os.path.abspath(__file__)) +class PABConfig: + def __init__( + self, + steps: int, + cross_broadcast: bool = False, + cross_threshold: list = None, + cross_range: int = None, + spatial_broadcast: bool = False, + spatial_threshold: list = None, + spatial_range: int = None, + temporal_broadcast: bool = False, + temporal_threshold: list = None, + temporal_range: int = None, + mlp_broadcast: bool = False, + mlp_spatial_broadcast_config: dict = None, + mlp_temporal_broadcast_config: dict = None, + ): + self.steps = steps + + self.cross_broadcast = cross_broadcast + self.cross_threshold = cross_threshold + self.cross_range = cross_range + + self.spatial_broadcast = spatial_broadcast + self.spatial_threshold = spatial_threshold + self.spatial_range = spatial_range + + self.temporal_broadcast = temporal_broadcast + self.temporal_threshold = temporal_threshold + self.temporal_range = temporal_range + + self.mlp_broadcast = mlp_broadcast + self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config + self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config + self.mlp_temporal_outputs = {} + self.mlp_spatial_outputs = {} + +class CogVideoXPABConfig(PABConfig): + def __init__( + self, + steps: int = 50, + spatial_broadcast: bool = True, + spatial_threshold: list = [100, 850], + spatial_range: int = 2, + ): + super().__init__( + steps=steps, + spatial_broadcast=spatial_broadcast, + spatial_threshold=spatial_threshold, + spatial_range=spatial_range, + ) + +from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as 
CogVideoXTransformer3DModelPAB + +class CogVideoPABConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}), + "pab_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ), + "pab_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ), + "pab_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ), + "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ), + } + } + + RETURN_TYPES = ("PAB_CONFIG",) + RETURN_NAMES = ("pab_config", ) + FUNCTION = "config" + CATEGORY = "CogVideoWrapper" + + def config(self, spatial_broadcast, pab_threshold_start, pab_threshold_end, pab_range, steps): + + pab_config = CogVideoXPABConfig( + steps=steps, + spatial_broadcast=spatial_broadcast, + spatial_threshold=[pab_threshold_end, pab_threshold_start], + spatial_range=pab_range + ) + + return (pab_config, ) + class DownloadAndLoadCogVideoModel: @classmethod def INPUT_TYPES(s): @@ -74,6 +155,7 @@ class DownloadAndLoadCogVideoModel: "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}), "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), + "pab_config": ("PAB_CONFIG", {"default": None}), } } @@ -82,7 +164,7 @@ class DownloadAndLoadCogVideoModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False): + def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -129,7 +211,10 @@ class DownloadAndLoadCogVideoModel: if "Fun" in model: transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder="transformer") else: - transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer") + if pab_config is not None: + transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder="transformer") + else: + transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer") transformer = transformer.to(dtype).to(offload_device) @@ -151,7 +236,7 @@ class DownloadAndLoadCogVideoModel: with open(scheduler_path) as f: scheduler_config = json.load(f) - scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) + scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) # VAE if "Fun" in model: @@ -159,7 +244,7 @@ class DownloadAndLoadCogVideoModel: pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) else: vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) - pipe = CogVideoXPipeline(vae, transformer, scheduler) + pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() @@ -729,7 +814,6 @@ class CogVideoDecode: video = 
pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt") video = video[0].permute(0, 2, 3, 1).cpu().float() - print(video.min(), video.max()) return (video,) @@ -979,7 +1063,8 @@ NODE_CLASS_MAPPINGS = { "CogVideoXFunSampler": CogVideoXFunSampler, "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler, "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine, - "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel + "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel, + "CogVideoPABConfig": CogVideoPABConfig } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -991,5 +1076,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoXFunSampler": "CogVideoXFun Sampler", "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler", "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine", - "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model" + "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model", + "CogVideoPABConfig": "CogVideo PABConfig" } diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 5465e91..3f460a6 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -32,6 +32,10 @@ from comfy.utils import ProgressBar logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from .videosys.core.pipeline import VideoSysPipeline +from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB +from .videosys.core.pab_mgr import set_pab_manager + def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): tw = tgt_width th = tgt_height @@ -108,7 +112,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -class CogVideoXPipeline(DiffusionPipeline): +class CogVideoXPipeline(VideoSysPipeline): r""" Pipeline for text-to-video generation using CogVideoX. @@ -137,9 +141,10 @@ class CogVideoXPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3DModel, + transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB], scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], - original_mask = None + original_mask = None, + pab_config = None ): super().__init__() @@ -155,6 +160,10 @@ class CogVideoXPipeline(DiffusionPipeline): self.original_mask = original_mask self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + if pab_config is not None: + print(pab_config) + set_pab_manager(pab_config) + def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None, ): diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py new file mode 100644 index 0000000..aeed7d3 --- /dev/null +++ b/videosys/cogvideox_transformer_3d.py @@ -0,0 +1,604 @@ +# Adapted from CogVideo + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# CogVideo: https://github.com/THUDM/CogVideo +# diffusers: https://github.com/huggingface/diffusers +# -------------------------------------------------------- + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import is_torch_version +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +from .core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence +from .core.pab_mgr import enable_pab, if_broadcast_spatial +#from .core.parallel_mgr import ParallelManager +from .modules.embeddings import apply_rotary_emb +from .utils.utils import batch_func + +from .modules.embeddings import CogVideoXPatchEmbed +from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # if attn.parallel_manager.sp_size > 1: + # assert ( + # attn.heads % attn.parallel_manager.sp_size == 0 + # ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}" + # attn_heads = attn.heads // attn.parallel_manager.sp_size + # query, key, value = map( + # lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1), + # [query, key, value], + # ) + + attn_heads = attn.heads + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn_heads + + query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + emb_len = image_rotary_emb[0].shape[0] + query[:, :, text_seq_length : emb_len + 
text_seq_length] = apply_rotary_emb( + query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb + ) + if not attn.is_cross_attention: + key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb( + key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb + ) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) + + #if attn.parallel_manager.sp_size > 1: + # hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class FusedCogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +@maybe_allow_in_graph +class 
CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + block_idx: int = 0, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # parallel + #self.attn1.parallel_manager = None + + # 2. 
Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # pab + self.attn_count = 0 + self.last_attn = None + self.block_idx = block_idx + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + timestep=None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + if enable_pab(): + broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx) + if enable_pab() and broadcast_attn: + attn_hidden_states, attn_encoder_hidden_states = self.last_attn + else: + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + if enable_pab(): + self.last_attn = (attn_hidden_states, attn_encoder_hidden_states) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. 
Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 + self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. 3D positional embeddings + spatial_pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + spatial_interpolation_scale, + temporal_interpolation_scale, + ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) + pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + # 3. 
Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 4. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 5. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + # parallel + #self.parallel_manager = None + + # def enable_parallel(self, dp_size, sp_size, enable_cp): + # # update cfg parallel + # if enable_cp and sp_size % 2 == 0: + # sp_size = sp_size // 2 + # cp_size = 2 + # else: + # cp_size = 1 + + # self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size) + + # for _, module in self.named_modules(): + # if hasattr(module, "parallel_manager"): + # module.parallel_manager = self.parallel_manager + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ): + # if self.parallel_manager.cp_size > 1: + # ( + # hidden_states, + # encoder_hidden_states, + # timestep, + # timestep_cond, + # image_rotary_emb, + # ) = batch_func( + # partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), + # hidden_states, + # encoder_hidden_states, + # timestep, + # timestep_cond, + # image_rotary_emb, + # ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + + # 3. 
Position embedding + text_seq_length = encoder_hidden_states.shape[1] + if not self.config.use_rotary_positional_embeddings: + seq_length = height * width * num_frames // (self.config.patch_size**2) + + pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + hidden_states = hidden_states + pos_embeds + hidden_states = self.embedding_dropout(hidden_states) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # if self.parallel_manager.sp_size > 1: + # set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group) + # hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) + + # 4. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + timestep=timesteps if enable_pab() else None, + ) + + #if self.parallel_manager.sp_size > 1: + # hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 5. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 6. 
Unpatchify + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + #if self.parallel_manager.cp_size > 1: + # output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/videosys/core/__init__.py b/videosys/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/videosys/core/comm.py b/videosys/core/comm.py new file mode 100644 index 0000000..3b44402 --- /dev/null +++ b/videosys/core/comm.py @@ -0,0 +1,406 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor +from torch.distributed import ProcessGroup + +# ====================================================== +# Model +# ====================================================== + + +def model_sharding(model: torch.nn.Module): + global_rank = dist.get_rank() + world_size = dist.get_world_size() + for _, param in model.named_parameters(): + padding_size = (world_size - param.numel() % world_size) % world_size + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // world_size) + splited_params = splited_params[global_rank] + param.data = splited_params + + +# ====================================================== +# AllGather & ReduceScatter +# ====================================================== + + +class AsyncAllGatherForTwo(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + weight: Tensor, + bias: Tensor, + sp_rank: int, + sp_size: int, + group: Optional[ProcessGroup] = None, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + from torch.distributed._functional_collectives import all_gather_tensor + + ctx.group = group + ctx.sp_rank = sp_rank + ctx.sp_size = sp_size + + # all gather inputs + all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group) + # compute local qkv + local_qkv = F.linear(inputs, weight, bias).unsqueeze(0) + + # remote compute + remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1]) + # compute remote qkv + remote_qkv = F.linear(remote_inputs, weight, bias) + + # concat local and remote qkv + if sp_rank == 0: + qkv = torch.cat([local_qkv, remote_qkv], dim=0) + else: + qkv = torch.cat([remote_qkv, local_qkv], dim=0) + qkv = rearrange(qkv, "sp b n c -> b (sp n) c") + + ctx.save_for_backward(inputs, weight, remote_inputs) + return qkv + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + from torch.distributed._functional_collectives import reduce_scatter_tensor + + group = ctx.group + sp_rank = ctx.sp_rank + sp_size = ctx.sp_size + inputs, weight, remote_inputs = ctx.saved_tensors + + # split qkv_grad + qkv_grad = grad_outputs[0] + qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size) + qkv_grad = torch.chunk(qkv_grad, 2, dim=0) + if sp_rank == 0: + local_qkv_grad, remote_qkv_grad = qkv_grad + else: + remote_qkv_grad, local_qkv_grad = qkv_grad + + # compute remote grad + remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0) + weight_grad = 
torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0) + bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0) + + # launch async reduce scatter + remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad) + if sp_rank == 0: + remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0) + else: + remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0) + remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group) + + # compute local grad and wait for reduce scatter + local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0) + weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0) + bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0) + + # sum remote and local grad + inputs_grad = remote_inputs_grad + local_input_grad + return inputs_grad, weight_grad, bias_grad, None, None, None + + +class AllGather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0), None + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + if not overlap: + dist.all_gather(buffer_list, inputs, group=group) + return outputs, None + else: + handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True) + return outputs, handle + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + return ( + ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], + None, + None, + ) + + +class ReduceScatter(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + group: ProcessGroup, + overlap: bool = False, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0), None + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + if not overlap: + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs, None + else: + handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True) + return outputs, handle + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + # TODO: support async backward + return ( + AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], + None, + None, + ) + + +# ====================================================== +# AlltoAll +# ====================================================== + + +def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) 
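+    # each rank has scattered its local chunks (split along scatter_dim) and received one chunk
+    # from every peer; concatenating them along gather_dim completes the all-to-all transpose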
+ return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + world_size = dist.get_world_size(process_group) + + return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, *grad_output): + process_group = ctx.process_group + scatter_dim = ctx.gather_dim + gather_dim = ctx.scatter_dim + return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + return (return_grad, None, None, None) + + +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + +# ====================================================== +# Sequence Gather & Split +# ====================================================== + + +def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int): + # skip if only one rank involved + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})" + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int): + # skip if only one rank involved + input_ = input_.contiguous() + world_size = dist.get_world_size(pg) + dist.get_rank(pg) + + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + assert input_.device.type == "cuda" + torch.distributed.all_gather(tensor_list, input_, group=pg) + + # concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """ + Gather the input sequence. + + Args: + input_: input matrix. + process_group: process group. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather_sequence_func(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale, pad): + ctx.process_group = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + ctx.pad = pad + return _gather_sequence_func(input_, process_group, dim, pad) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.process_group) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.process_group) + + return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split sequence. + + Args: + input_: input matrix. + process_group: parallel mode. 
+ dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split_sequence_func(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale, pad): + ctx.process_group = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + ctx.pad = pad + return _split_sequence_func(input_, process_group, dim, pad) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.process_group) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.process_group) + return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None + + +def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0): + return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad) + + +def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0): + return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad) + + +# ============================== +# Pad +# ============================== + +PAD_DICT = {} + + +def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup): + sp_size = dist.get_world_size(parallel_group) + pad = (sp_size - (dim_size % sp_size)) % sp_size + global PAD_DICT + PAD_DICT[name] = pad + + +def get_pad(name) -> int: + return PAD_DICT[name] + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + scatter_dim: int = 2, + gather_dim: int = 1, + scatter_pad: int = 0, + gather_pad: int = 0, +): + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + assert ( + input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0 + ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})" + input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py new file mode 100644 index 0000000..6f19a50 --- /dev/null +++ b/videosys/core/pab_mgr.py @@ -0,0 +1,232 @@ + +PAB_MANAGER = None + + +class PABConfig: + def __init__( + self, + steps: int, + cross_broadcast: bool = False, + cross_threshold: list = None, + cross_range: int = None, + spatial_broadcast: bool = False, + spatial_threshold: list = None, + spatial_range: int = None, + temporal_broadcast: bool = False, + temporal_threshold: list = None, + temporal_range: int = None, + mlp_broadcast: bool = False, + mlp_spatial_broadcast_config: dict = None, + mlp_temporal_broadcast_config: dict = None, + ): + self.steps = steps + + self.cross_broadcast = cross_broadcast + self.cross_threshold = cross_threshold + self.cross_range = cross_range + + self.spatial_broadcast = spatial_broadcast + self.spatial_threshold = spatial_threshold + self.spatial_range = spatial_range + + self.temporal_broadcast = temporal_broadcast + self.temporal_threshold = temporal_threshold + self.temporal_range = temporal_range + + self.mlp_broadcast = mlp_broadcast + self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config + self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config + self.mlp_temporal_outputs = {} + self.mlp_spatial_outputs = {} + + +class 
PABManager: + def __init__(self, config: PABConfig): + self.config: PABConfig = config + + init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}." + init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}." + init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}." + init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}." + init_prompt += f" mlp broadcast: {config.mlp_broadcast}." + print(init_prompt) + + def if_broadcast_cross(self, timestep: int, count: int): + if ( + self.config.cross_broadcast + and (timestep is not None) + and (count % self.config.cross_range != 0) + and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1]) + ): + flag = True + else: + flag = False + count = (count + 1) % self.config.steps + return flag, count + + def if_broadcast_temporal(self, timestep: int, count: int): + if ( + self.config.temporal_broadcast + and (timestep is not None) + and (count % self.config.temporal_range != 0) + and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1]) + ): + flag = True + else: + flag = False + count = (count + 1) % self.config.steps + return flag, count + + def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int): + if ( + self.config.spatial_broadcast + and (timestep is not None) + and (count % self.config.spatial_range != 0) + and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1]) + ): + flag = True + else: + flag = False + count = (count + 1) % self.config.steps + return flag, count + + @staticmethod + def _is_t_in_skip_config(all_timesteps, timestep, config): + is_t_in_skip_config = False + skip_range = None + for key in config: + if key not in all_timesteps: + continue + index = all_timesteps.index(key) + skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])] + if timestep in skip_range: + is_t_in_skip_config = True + skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]] + break + return is_t_in_skip_config, skip_range + + def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): + if not self.config.mlp_broadcast: + return False, None, False, None + + if is_temporal: + cur_config = self.config.mlp_temporal_broadcast_config + else: + cur_config = self.config.mlp_spatial_broadcast_config + + is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config) + next_flag = False + if ( + self.config.mlp_broadcast + and (timestep is not None) + and (timestep in cur_config) + and (block_idx in cur_config[timestep]["block"]) + ): + flag = False + next_flag = True + count = count + 1 + elif ( + self.config.mlp_broadcast + and (timestep is not None) + and (is_t_in_skip_config) + and (block_idx in cur_config[skip_range[0]]["block"]) + ): + flag = True + count = 0 + else: + flag = False + + return flag, count, next_flag, skip_range + + def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False): + if is_temporal: + self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output + else: + self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output + + def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False): + 
skip_start_t = skip_range[0] + if is_temporal: + skip_output = ( + self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None) + if self.config.mlp_temporal_outputs is not None + else None + ) + else: + skip_output = ( + self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None) + if self.config.mlp_spatial_outputs is not None + else None + ) + + if skip_output is not None: + if timestep == skip_range[-1]: + # TODO: save memory + if is_temporal: + del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)] + else: + del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)] + else: + raise ValueError( + f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}" + ) + + return skip_output + + def get_spatial_mlp_outputs(self): + return self.config.mlp_spatial_outputs + + def get_temporal_mlp_outputs(self): + return self.config.mlp_temporal_outputs + + +def set_pab_manager(config: PABConfig): + global PAB_MANAGER + PAB_MANAGER = PABManager(config) + + +def enable_pab(): + if PAB_MANAGER is None: + return False + return ( + PAB_MANAGER.config.cross_broadcast + or PAB_MANAGER.config.spatial_broadcast + or PAB_MANAGER.config.temporal_broadcast + ) + + +def update_steps(steps: int): + if PAB_MANAGER is not None: + PAB_MANAGER.config.steps = steps + + +def if_broadcast_cross(timestep: int, count: int): + if not enable_pab(): + return False, count + return PAB_MANAGER.if_broadcast_cross(timestep, count) + + +def if_broadcast_temporal(timestep: int, count: int): + if not enable_pab(): + return False, count + return PAB_MANAGER.if_broadcast_temporal(timestep, count) + + +def if_broadcast_spatial(timestep: int, count: int, block_idx: int): + if not enable_pab(): + return False, count + return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx) + + +def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): + if not enable_pab(): + return False, count + return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal) + + +def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False): + return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal) + + +def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False): + return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal) diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py new file mode 100644 index 0000000..75b79d3 --- /dev/null +++ b/videosys/core/pipeline.py @@ -0,0 +1,52 @@ +import inspect +from abc import abstractmethod +from dataclasses import dataclass + +import torch +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput + + +class VideoSysPipeline(DiffusionPipeline): + def __init__(self): + super().__init__() + + @staticmethod + def set_eval_and_device(device: torch.device, *modules): + for module in modules: + module.eval() + module.to(device) + + @abstractmethod + def generate(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + """ + In diffusers, it is a convention to call the pipeline object. + But in VideoSys, we will use the generate method for better prompt. + This is a wrapper for the generate method to support the diffusers usage. 
+ """ + return self.generate(*args, **kwargs) + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + # modify: remove the config module from the expected modules + expected_modules = expected_modules - {"config"} + + optional_names = list(optional_parameters) + for name in optional_names: + if name in cls._optional_components: + expected_modules.add(name) + optional_parameters.remove(name) + + return expected_modules, optional_parameters + + +@dataclass +class VideoSysPipelineOutput(BaseOutput): + video: torch.Tensor diff --git a/videosys/modules/__init__.py b/videosys/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/videosys/modules/activations.py b/videosys/modules/activations.py new file mode 100644 index 0000000..cf24149 --- /dev/null +++ b/videosys/modules/activations.py @@ -0,0 +1,3 @@ +import torch.nn as nn + +approx_gelu = lambda: nn.GELU(approximate="tanh") diff --git a/videosys/modules/attentions.py b/videosys/modules/attentions.py new file mode 100644 index 0000000..8e2c20c --- /dev/null +++ b/videosys/modules/attentions.py @@ -0,0 +1,205 @@ +from dataclasses import dataclass +from typing import Iterable, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from videosys.models.modules.normalization import LlamaRMSNorm + + +class OpenSoraAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = LlamaRMSNorm, + enable_flash_attn: bool = False, + rope=None, + qk_norm_legacy: bool = False, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flash_attn = enable_flash_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.qk_norm_legacy = qk_norm_legacy + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = False + if rope is not None: + self.rope = True + self.rotary_emb = rope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + # flash attn is not memory efficient for small sequences, this is empirical + enable_flash_attn = self.enable_flash_attn and (N > B) + qkv = self.qkv(x) + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + + qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + if self.qk_norm_legacy: + # WARNING: this may be a bug + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + q, k = self.q_norm(q), self.k_norm(k) + else: + q, k = self.q_norm(q), self.k_norm(k) + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + + if enable_flash_attn: + from flash_attn import flash_attn_func + + # (B, #heads, N, #dim) -> (B, N, #heads, #dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func( + q, 
+ k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + else: + x = F.scaled_dot_product_attention(q, k, v) + + x_output_shape = (B, N, C) + if not enable_flash_attn: + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class OpenSoraMultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False): + super(OpenSoraMultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model * 2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + self.enable_flash_attn = enable_flash_attn + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + if self.enable_flash_attn: + x = self.flash_attn_impl(q, k, v, mask, B, N, C) + else: + x = self.torch_impl(q, k, v, mask, B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + def flash_attn_impl(self, q, k, v, mask, B, N, C): + from flash_attn import flash_attn_varlen_func + + q_seqinfo = _SeqLenInfo.from_seqlens([N] * B) + k_seqinfo = _SeqLenInfo.from_seqlens(mask) + + x = flash_attn_varlen_func( + q.view(-1, self.num_heads, self.head_dim), + k.view(-1, self.num_heads, self.head_dim), + v.view(-1, self.num_heads, self.head_dim), + cu_seqlens_q=q_seqinfo.seqstart.cuda(), + cu_seqlens_k=k_seqinfo.seqstart.cuda(), + max_seqlen_q=q_seqinfo.max_seqlen, + max_seqlen_k=k_seqinfo.max_seqlen, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + x = x.view(B, N, C) + return x + + def torch_impl(self, q, k, v, mask, B, N, C): + q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + + attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device) + for i, m in enumerate(mask): + attn_mask[i, :, :, :m] = -1e9 + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = out.transpose(1, 2).contiguous().view(B, N, C) + return x + + +@dataclass +class _SeqLenInfo: + """ + from xformers + + (Internal) Represents the division of a dimension into blocks. + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. 
+ The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) diff --git a/videosys/modules/downsampling.py b/videosys/modules/downsampling.py new file mode 100644 index 0000000..9455a32 --- /dev/null +++ b/videosys/modules/downsampling.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CogVideoXDownsample3D(nn.Module): + # Todo: Wait for paper relase. + r""" + A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `2`): + Stride of the convolution. + padding (`int`, defaults to `0`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. 
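+
+    Example (illustrative sketch; the channel counts and frame count below are
+    assumed for demonstration, not prescribed by the layer):
+
+        >>> layer = CogVideoXDownsample3D(128, 256, compress_time=True)
+        >>> x = torch.randn(1, 128, 9, 32, 32)  # (batch, channels, frames, height, width)
+        >>> layer(x).shape  # frames 9 -> 5 (first frame kept), spatial dims halved
+        torch.Size([1, 256, 5, 16, 16])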
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 0, + compress_time: bool = False, + ): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + batch_size, channels, frames, height, width = x.shape + + # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) + x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) + + if x.shape[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) + x = F.avg_pool1d(x, kernel_size=2, stride=2) + # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + + # Pad the tensor + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) + x = self.conv(x) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x diff --git a/videosys/modules/embeddings.py b/videosys/modules/embeddings.py new file mode 100644 index 0000000..13dd629 --- /dev/null +++ b/videosys/modules/embeddings.py @@ -0,0 +1,412 @@ +import functools +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from timm.models.vision_transformer import Mlp + + +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + ) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. 
Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + return embeds + + +class OpenSoraPatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
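+
+        Example (sketch; the timestep values are arbitrary):
+
+            >>> t = torch.tensor([0.0, 10.0, 999.0])
+            >>> TimestepEmbedder.timestep_embedding(t, dim=256).shape
+            torch.Size([3, 256])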
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs // s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class OpenSoraCaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__( + self, + in_channels, + hidden_size, + uncond_prob, + act_layer=nn.GELU(approximate="tanh"), + token_num=120, + ): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, + hidden_features=hidden_size, + out_features=hidden_size, + act_layer=act_layer, + drop=0, + ) + self.register_buffer( + "y_embedding", + torch.randn(token_num, in_channels) / in_channels**0.5, + ) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class OpenSoraPositionEmbedding2D(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + assert dim % 4 == 0, "dim must be divisible by 4" + half_dim = dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def _get_sin_cos_emb(self, t: torch.Tensor): + out = torch.einsum("i,d->id", t, self.inv_freq) + emb_cos = torch.cos(out) + emb_sin = torch.sin(out) + return torch.cat((emb_sin, emb_cos), dim=-1) + + @functools.lru_cache(maxsize=512) + def _get_cached_emb( + self, + device: torch.device, + dtype: torch.dtype, + h: int, + w: int, + scale: float = 1.0, + base_size: Optional[int] = None, + ): + grid_h = torch.arange(h, device=device) / scale + grid_w = torch.arange(w, device=device) / scale + if base_size is not None: + grid_h *= base_size / h + grid_w *= base_size / w + grid_h, grid_w = torch.meshgrid( + grid_w, + grid_h, + indexing="ij", + ) # here w goes first + grid_h = grid_h.t().reshape(-1) + grid_w = grid_w.t().reshape(-1) + emb_h = self._get_sin_cos_emb(grid_h) + emb_w = self._get_sin_cos_emb(grid_w) + return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype) + + def forward( + self, + x: torch.Tensor, + h: int, + w: int, + scale: Optional[float] = 1.0, + base_size: Optional[int] = None, + ) -> torch.Tensor: + return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size) + + +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. 
+ """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) + grid_t = torch.from_numpy(grid_t).float() + freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) + freqs_t = freqs_t.repeat_interleave(2, dim=-1) + + # Spatial frequencies for height and width + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) + freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) + freqs_h = freqs_h.repeat_interleave(2, dim=-1) + freqs_w = freqs_w.repeat_interleave(2, dim=-1) + + # Broadcast and concatenate tensors along specified dimension + def broadcast(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*(len(set(t[1])) <= 2 for t in expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + + t, h, w, d = freqs.shape + freqs = freqs.view(t * h * w, d) + + # Generate sine and cosine components + sin = freqs.sin() + cos = freqs.cos() + + if use_real: + return cos, sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Use for example in Lumina + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Use for example in Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) diff --git a/videosys/modules/normalization.py b/videosys/modules/normalization.py new file mode 100644 index 0000000..7985e56 --- /dev/null +++ b/videosys/modules/normalization.py @@ -0,0 +1,102 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. 
+ output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x diff --git a/videosys/modules/upsampling.py b/videosys/modules/upsampling.py new file mode 100644 index 0000000..f9a61b7 --- /dev/null +++ b/videosys/modules/upsampling.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CogVideoXUpsample3D(nn.Module): + r""" + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. 
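+
+    Example (illustrative sketch; channel and frame counts are arbitrary and roughly
+    mirror the downsampling example):
+
+        >>> layer = CogVideoXUpsample3D(256, 128, compress_time=True)
+        >>> x = torch.randn(1, 256, 5, 16, 16)  # (batch, channels, frames, height, width)
+        >>> layer(x).shape  # frames 5 -> 9, spatial dims doubled
+        torch.Size([1, 128, 9, 32, 32])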
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False, + ) -> None: + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs diff --git a/videosys/utils/logging.py b/videosys/utils/logging.py new file mode 100644 index 0000000..896a4d6 --- /dev/null +++ b/videosys/utils/logging.py @@ -0,0 +1,32 @@ +import logging + +import torch.distributed as dist +from rich.logging import RichHandler + + +def create_logger(): + """ + Create a logger that writes to a log file and stdout. + """ + logger = logging.getLogger(__name__) + return logger + + +def init_dist_logger(): + """ + Update the logger to write to a log file. + """ + global logger + if dist.get_rank() == 0: + logger = logging.getLogger(__name__) + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) + formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + + +logger = create_logger() diff --git a/videosys/utils/utils.py b/videosys/utils/utils.py new file mode 100644 index 0000000..622a36d --- /dev/null +++ b/videosys/utils/utils.py @@ -0,0 +1,92 @@ +import os +import random + +import imageio +import numpy as np +import torch +import torch.distributed as dist +from omegaconf import DictConfig, ListConfig, OmegaConf + + +def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: + """ + Set requires_grad flag for all parameters in a model. 
+ """ + for p in model.parameters(): + p.requires_grad = flag + + +def set_seed(seed, dp_rank=None): + if seed == -1: + seed = random.randint(0, 1000000) + + if dp_rank is not None: + seed = torch.tensor(seed, dtype=torch.int64).cuda() + if dist.get_world_size() > 1: + dist.broadcast(seed, 0) + seed = seed + dp_rank + + seed = int(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def str_to_dtype(x: str): + if x == "fp32": + return torch.float32 + elif x == "fp16": + return torch.float16 + elif x == "bf16": + return torch.bfloat16 + else: + raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}") + + +def batch_func(func, *args): + """ + Apply a function to each element of a batch. + """ + batch = [] + for arg in args: + if isinstance(arg, torch.Tensor) and arg.shape[0] == 2: + batch.append(func(arg)) + else: + batch.append(arg) + + return batch + + +def merge_args(args1, args2): + """ + Merge two argparse Namespace objects. + """ + if args2 is None: + return args1 + + for k in args2._content.keys(): + if k in args1.__dict__: + v = getattr(args2, k) + if isinstance(v, ListConfig) or isinstance(v, DictConfig): + v = OmegaConf.to_object(v) + setattr(args1, k, v) + else: + raise RuntimeError(f"Unknown argument {k}") + + return args1 + + +def all_exists(paths): + return all(os.path.exists(path) for path in paths) + + +def save_video(video, output_path, fps): + """ + Save a video to disk. + """ + if dist.is_initialized() and dist.get_rank() != 0: + return + os.makedirs(os.path.dirname(output_path), exist_ok=True) + imageio.mimwrite(output_path, video, fps=fps)