From e783951dadc9f672f67cd7c7234ca336e7eda7d0 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 9 Nov 2024 02:24:18 +0200 Subject: [PATCH] maybe --- custom_cogvideox_transformer_3d.py | 369 +++++++++++++++-------------- embeddings.py | 100 ++++---- pipeline_cogvideox.py | 35 +-- 3 files changed, 262 insertions(+), 242 deletions(-) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 0e36cba..ebad39e 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -21,20 +21,18 @@ import torch.nn.functional as F import numpy as np from einops import rearrange -from functools import reduce -from operator import mul from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import logging from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.attention import Attention, FeedForward from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero from diffusers.loaders import PeftAdapterMixin -from .embeddings import CogVideoX1_1PatchEmbed +from .embeddings import CogVideoXPatchEmbed logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -64,14 +62,6 @@ def fft(tensor): return low_freq_fft, high_freq_fft -def rotate_half(x): - x = rearrange(x, "... (d r) -> ... d r", r=2) - x1, x2 = x.unbind(dim=-1) - x = torch.stack((-x2, x1), dim=-1) - return rearrange(x, "... d r -> ... (d r)") - - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. 
It applies a rotary embedding on @@ -81,16 +71,7 @@ class CogVideoXAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - def rotary(self, t, rope_args): - def reshape_freq(freqs): - freqs = freqs[: rope_args["T"], : rope_args["H"], : rope_args["W"]].contiguous() - freqs = rearrange(freqs, "t h w d -> (t h w) d") - freqs = freqs.unsqueeze(0).unsqueeze(0) - return freqs - freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype) - freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype) - - return t * freqs_cos + rotate_half(t) * freqs_sin + @torch.compiler.disable() def __call__( self, @@ -99,7 +80,6 @@ class CogVideoXAttnProcessor2_0: encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - rope_args: Optional[dict] = None ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) @@ -129,127 +109,118 @@ class CogVideoXAttnProcessor2_0: if attn.norm_k is not None: key = attn.norm_k(key) - - - # Apply RoPE if needed - if image_rotary_emb is not None: - self.freqs_cos = image_rotary_emb[0] - self.freqs_sin = image_rotary_emb[1] - print("rope args", rope_args) #{'T': 6, 'H': 30, 'W': 45, 'seq_length': 8775} - print("freqs_cos", self.freqs_cos.shape) #torch.Size([13, 30, 45, 64]) - print("freqs_sin", self.freqs_sin.shape) - - - from diffusers.models.embeddings import apply_rotary_emb - - #query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - query = torch.cat( - (query[:, :, : text_seq_length], - self.rotary(query[:, :, text_seq_length:], - rope_args)), - dim=2) - - if not attn.is_cross_attention: - #key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - key = torch.cat( - (key[ :, :, : text_seq_length], - self.rotary(key[:, :, text_seq_length:], - rope_args)), - dim=2) - - if SAGEATTN_IS_AVAILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - - -class FusedCogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. 
- """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - @torch.compiler.disable() - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed if image_rotary_emb is not None: from diffusers.models.embeddings import apply_rotary_emb + has_nan = torch.isnan(query).any() + if has_nan: + raise ValueError(f"query before rope has nan: {has_nan}") - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - if SAGEATTN_IS_AVAILABLE: - hidden_states = sageattn(query, key, value, is_causal=False) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + #if SAGEATTN_IS_AVAILABLE: + # hidden_states = sageattn(query, key, value, is_causal=False) + #else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + has_nan = torch.isnan(hidden_states).any() + if has_nan: + raise ValueError(f"hs after scaled_dot_product_attention has nan: {has_nan}") + has_inf = torch.isinf(hidden_states).any() + if has_inf: + raise ValueError(f"hs after scaled_dot_product_attention has inf: {has_inf}") hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # linear proj hidden_states = attn.to_out[0](hidden_states) + has_nan = torch.isnan(hidden_states).any() + # dropout hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states, hidden_states = hidden_states.split( [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 ) return hidden_states, encoder_hidden_states - + + +# class FusedCogVideoXAttnProcessor2_0: +# r""" +# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on +# query and key vectors, but does not include spatial normalization. 
+# """ + +# def __init__(self): +# if not hasattr(F, "scaled_dot_product_attention"): +# raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") +# @torch.compiler.disable() +# def __call__( +# self, +# attn: Attention, +# hidden_states: torch.Tensor, +# encoder_hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# image_rotary_emb: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# print("FusedCogVideoXAttnProcessor2_0") +# text_seq_length = encoder_hidden_states.size(1) + +# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + +# batch_size, sequence_length, _ = ( +# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape +# ) + +# if attention_mask is not None: +# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) +# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + +# qkv = attn.to_qkv(hidden_states) +# split_size = qkv.shape[-1] // 3 +# query, key, value = torch.split(qkv, split_size, dim=-1) + +# inner_dim = key.shape[-1] +# head_dim = inner_dim // attn.heads + +# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + +# if attn.norm_q is not None: +# query = attn.norm_q(query) +# if attn.norm_k is not None: +# key = attn.norm_k(key) + +# # Apply RoPE if needed +# if image_rotary_emb is not None: +# from diffusers.models.embeddings import apply_rotary_emb + +# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) +# if not attn.is_cross_attention: +# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + +# if SAGEATTN_IS_AVAILABLE: +# hidden_states = sageattn(query, key, value, is_causal=False) +# else: +# hidden_states = F.scaled_dot_product_attention( +# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False +# ) + +# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + +# # linear proj +# hidden_states = attn.to_out[0](hidden_states) +# # dropout +# hidden_states = attn.to_out[1](hidden_states) + +# encoder_hidden_states, hidden_states = hidden_states.split( +# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 +# ) +# return hidden_states, encoder_hidden_states + +#region Blocks @maybe_allow_in_graph class CogVideoXBlock(nn.Module): @@ -344,14 +315,14 @@ class CogVideoXBlock(nn.Module): fuser=None, fastercache_counter=0, fastercache_start_step=15, - fastercache_device="cuda:0", - rope_args=None + fastercache_device="cuda:0" ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) + # Tora Motion-guidance Fuser if video_flow_feature is not None: H, W = video_flow_feature.shape[-2:] @@ -378,7 +349,7 @@ class CogVideoXBlock(nn.Module): attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, rope_args=rope_args + image_rotary_emb=image_rotary_emb ) if fastercache_counter == fastercache_start_step: self.cached_hidden_states = 
[attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)] @@ -386,10 +357,18 @@ class CogVideoXBlock(nn.Module): elif fastercache_counter > fastercache_start_step: self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device)) self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device)) - + + hidden_states = hidden_states + gate_msa * attn_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + # has_nan = torch.isnan(hidden_states).any() + # if has_nan: + # raise ValueError(f"hs before norm2 has nan: {has_nan}") + # has_inf = torch.isinf(hidden_states).any() + # if has_inf: + # raise ValueError(f"hs before norm2 has inf: {has_inf}") + # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb @@ -404,7 +383,7 @@ class CogVideoXBlock(nn.Module): return hidden_states, encoder_hidden_states - +#region Transformer class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). @@ -479,6 +458,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, + patch_size_t: int = 2, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", @@ -489,6 +469,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, use_learned_positional_embeddings: bool = False, + patch_bias: bool = True, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -501,12 +482,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ) # 1. Patch embedding - self.patch_embed = CogVideoX1_1PatchEmbed( + self.patch_embed = CogVideoXPatchEmbed( patch_size=patch_size, + patch_size_t=patch_size_t, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim, - #bias=True, + bias=patch_bias, sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames, @@ -550,7 +532,14 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): norm_eps=norm_eps, chunk_dim=1, ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels) + if patch_size_t is None: + # For CogVideox 1.0 + output_dim = patch_size * patch_size * out_channels + else: + # For CogVideoX 1.5 + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = nn.Linear(inner_dim, output_dim) self.gradient_checkpointing = False @@ -626,44 +615,44 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. + # def fuse_qkv_projections(self): + # """ + # Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + # are fused. 
For cross-attention modules, key and value projection matrices are fused. - + # - This API is 🧪 experimental. + # This API is 🧪 experimental. - - """ - self.original_attn_processors = None + # + # """ + # self.original_attn_processors = None - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + # for _, attn_processor in self.attn_processors.items(): + # if "Added" in str(attn_processor.__class__.__name__): + # raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - self.original_attn_processors = self.attn_processors + # self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + # for module in self.modules(): + # if isinstance(module, Attention): + # module.fuse_projections(fuse=True) - self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + # self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. + # def unfuse_qkv_projections(self): + # """Disables the fused QKV projection if enabled. - + # - This API is 🧪 experimental. + # This API is 🧪 experimental. - + # - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + # """ + # if self.original_attn_processors is not None: + # self.set_attn_processor(self.original_attn_processors) def forward( self, @@ -678,9 +667,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape - p = self.config.patch_size - print("p", p) - + # 1. Time embedding timesteps = timestep t_emb = self.time_proj(timesteps) @@ -691,25 +678,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) - # RoPE - seq_length = num_frames * height * width // reduce(mul, [p, p, p]) - rope_T = num_frames // p - rope_H = height // p - rope_W = width // p - rope_args = { - "T": rope_T, - "H": rope_H, - "W": rope_W, - "seq_length": seq_length, - } - # 2. Patch embedding + p = self.config.patch_size + p_t = self.config.patch_size_t + + # We know that the hidden states height and width will always be divisible by patch_size. + # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames. 
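# Worked example of the padding arithmetic below (a sketch, assuming the new transformer default patch_size_t = 2):
# a 21-frame latent gives remaining_frames = 2 - 21 % 2 = 1, so one extra copy of the first frame is
# prepended, yielding 22 frames = 11 temporal patches. The padded frames are dropped again after
# unpatchifying via output[:, remaining_frames:], and 11 matches the base_num_frames =
# (num_frames + p_t - 1) // p_t used for the rotary embeddings in the pipeline.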
+ if p_t is not None: + remaining_frames = p_t - num_frames % p_t + first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1) + hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.embedding_dropout(hidden_states) - + text_seq_length = encoder_hidden_states.shape[1] encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] + if self.use_fastercache: self.fastercache_counter+=1 if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0: @@ -754,8 +740,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + output = output[:, remaining_frames:] (bb, tt, cc, hh, ww) = output.shape cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) @@ -777,6 +770,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): output = torch.cat([output, recovered_uncond]) else: for i, block in enumerate(self.transformer_blocks): + print("block", i) hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -785,9 +779,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, fuser = self.fuser_list[i] if self.fuser_list is not None else None, fastercache_counter = self.fastercache_counter, - fastercache_device = self.fastercache_device, - rope_args=rope_args + fastercache_device = self.fastercache_device ) + has_nan = torch.isnan(hidden_states).any() + if has_nan: + raise ValueError(f"block output hidden_states has nan: {has_nan}") if (controlnet_states is not None) and (i < len(controlnet_states)): controlnet_states_block = controlnet_states[i] @@ -816,9 +812,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # Note: we use `-1` instead of `channels`: # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = 
hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + output = output[:, remaining_frames:] if self.fastercache_counter >= self.fastercache_start_step + 1: (bb, tt, cc, hh, ww) = output.shape diff --git a/embeddings.py b/embeddings.py index 9747e91..bc3bf7f 100644 --- a/embeddings.py +++ b/embeddings.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import numpy as np -from typing import Tuple, Union +from typing import Tuple, Union, Optional def get_1d_rotary_pos_embed( dim: int, @@ -123,9 +123,9 @@ def get_3d_rotary_pos_embed( freqs = torch.cat( [freqs_t, freqs_h, freqs_w], dim=-1 ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) - #freqs = freqs.view( - # temporal_size * grid_size_h * grid_size_w, -1 - #) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) return freqs t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t @@ -236,16 +236,18 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb -class CogVideoX1_1PatchEmbed(nn.Module): +class CogVideoXPatchEmbed(nn.Module): def __init__( self, patch_size: int = 2, + patch_size_t: Optional[int] = None, in_channels: int = 16, embed_dim: int = 1920, text_embed_dim: int = 4096, + bias: bool = True, sample_width: int = 90, sample_height: int = 60, - sample_frames: int = 81, + sample_frames: int = 49, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, spatial_interpolation_scale: float = 1.875, @@ -255,8 +257,8 @@ class CogVideoX1_1PatchEmbed(nn.Module): ) -> None: super().__init__() - # Adjust patch_size to handle three dimensions - self.patch_size = (patch_size, patch_size, patch_size) # (depth, height, width) + self.patch_size = patch_size + self.patch_size_t = patch_size_t self.embed_dim = embed_dim self.sample_height = sample_height self.sample_width = sample_width @@ -268,8 +270,15 @@ class CogVideoX1_1PatchEmbed(nn.Module): self.use_positional_embeddings = use_positional_embeddings self.use_learned_positional_embeddings = use_learned_positional_embeddings - # Use Linear layer for projection - self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim) + if patch_size_t is None: + # CogVideoX 1.0 checkpoints + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + else: + # CogVideoX 1.5 checkpoints + self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) if use_positional_embeddings or use_learned_positional_embeddings: @@ -278,8 +287,8 @@ class CogVideoX1_1PatchEmbed(nn.Module): self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: - post_patch_height = sample_height // self.patch_size[1] - post_patch_width = sample_width // self.patch_size[2] + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 num_patches = post_patch_height * post_patch_width * 
post_time_compression_frames @@ -291,44 +300,46 @@ class CogVideoX1_1PatchEmbed(nn.Module): self.temporal_interpolation_scale, ) pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) - joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False) - joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) return joint_pos_embedding def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - """ + r""" Args: - text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim). - image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width). + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). """ text_embeds = self.text_proj(text_embeds) - first_frame = image_embeds[:, 0:1, :, :, :] - duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1) # (batch, 2, channels, height, width) - # Copy the first frames, for t_patch - image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1) - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous() - image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1) - rope_patch_t = num_frames // self.patch_size[0] - rope_patch_h = height // self.patch_size[1] - rope_patch_w = width // self.patch_size[2] + batch_size, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.view( - batch, - rope_patch_t, self.patch_size[0], - rope_patch_h, self.patch_size[1], - rope_patch_w, self.patch_size[2], - channels - ) - image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() - image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1) - image_embeds = self.proj(image_embeds) - # Concatenate text and image embeddings - embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + else: + p = self.patch_size + p_t = self.patch_size_t + + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] - # Add positional embeddings if applicable if self.use_positional_embeddings or self.use_learned_positional_embeddings: if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): raise 
ValueError( @@ -339,9 +350,9 @@ class CogVideoX1_1PatchEmbed(nn.Module): pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 if ( - self.sample_height != height - or self.sample_width != width - or self.sample_frames != pre_time_compression_frames + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames ): pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) @@ -350,4 +361,5 @@ class CogVideoX1_1PatchEmbed(nn.Module): embeds = embeds + pos_embedding - return embeds \ No newline at end of file + return embeds + \ No newline at end of file diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 571498a..a563b73 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -21,6 +21,7 @@ import torch.nn.functional as F import math from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor @@ -115,7 +116,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for text-to-video generation using CogVideoX. @@ -298,18 +299,18 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) return weights - def fuse_qkv_projections(self) -> None: - r"""Enables fused QKV projections.""" - self.fusing_transformer = True - self.transformer.fuse_qkv_projections() + # def fuse_qkv_projections(self) -> None: + # r"""Enables fused QKV projections.""" + # self.fusing_transformer = True + # self.transformer.fuse_qkv_projections() - def unfuse_qkv_projections(self) -> None: - r"""Disable QKV projection fusion if enabled.""" - if not self.fusing_transformer: - logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") - else: - self.transformer.unfuse_qkv_projections() - self.fusing_transformer = False + # def unfuse_qkv_projections(self) -> None: + # r"""Disable QKV projection fusion if enabled.""" + # if not self.fusing_transformer: + # logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + # else: + # self.transformer.unfuse_qkv_projections() + # self.fusing_transformer = False def _prepare_rotary_positional_embeddings( self, @@ -322,8 +323,12 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -332,7 +337,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, use_real=True, )
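
For reference, two short sketches of the arithmetic this patch introduces; the helper name rope_grid, the vae_scale_factor_spatial value, and the tensor sizes are illustrative assumptions rather than part of the patch.

# Sketch: grid sizes fed to get_3d_rotary_pos_embed by the updated
# _prepare_rotary_positional_embeddings (vae_scale_factor_spatial = 8 assumed).
def rope_grid(height, width, num_frames, vae_scale_factor_spatial=8, patch_size=2, patch_size_t=2):
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    # round latent frames up to whole temporal patches, mirroring the frame padding in forward()
    base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
    return grid_height, grid_width, base_num_frames

print(rope_grid(768, 1360, 21))  # -> (48, 85, 11)

The second sketch is a round-trip check of the 3D patchify/unpatchify pair added for patch_size_t: the reshapes and permutes copy the ones in CogVideoXPatchEmbed.forward and in the transformer's unpatchify branch, with the linear projection omitted so the mapping can be inverted exactly.

import torch

B, T, C, H, W = 1, 22, 16, 96, 170  # latent frames already padded to a multiple of p_t
p, p_t = 2, 2
x = torch.randn(B, T, C, H, W)

# patchify: (B, T, C, H, W) -> (B, T/p_t * H/p * W/p, C * p_t * p * p)
tokens = x.permute(0, 1, 3, 4, 2)
tokens = tokens.reshape(B, T // p_t, p_t, H // p, p, W // p, p, C)
tokens = tokens.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)

# unpatchify: invert the permutes to recover (B, T, C, H, W)
out = tokens.reshape(B, T // p_t, H // p, W // p, -1, p_t, p, p)
out = out.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
assert torch.equal(out, x)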