diff --git a/convert_weight_sat2hf.py b/convert_weight_sat2hf.py
new file mode 100644
index 0000000..545925b
--- /dev/null
+++ b/convert_weight_sat2hf.py
@@ -0,0 +1,303 @@
+"""
+
+The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
+This script supports the conversion of the following models:
+- CogVideoX-2B
+- CogVideoX-5B, CogVideoX-5B-I2V
+- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
+
+Original Script:
+https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
+
+"""
+import argparse
+from typing import Any, Dict
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    #CogVideoXTransformer3DModel,
+)
+from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
+
+
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
+    to_q_key = key.replace("query_key_value", "to_q")
+    to_k_key = key.replace("query_key_value", "to_k")
+    to_v_key = key.replace("query_key_value", "to_v")
+    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+    state_dict[to_q_key] = to_q
+    state_dict[to_k_key] = to_k
+    state_dict[to_v_key] = to_v
+    state_dict.pop(key)
+
+
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, weight_or_bias = key.split(".")[-2:]
+
+    if "query" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
+    elif "key" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
+
+    state_dict[new_key] = state_dict.pop(key)
+
+
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, _, weight_or_bias = key.split(".")[-3:]
+
+    weights_or_biases = state_dict[key].chunk(12, dim=0)
+    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+    state_dict[norm1_key] = norm1_weights_or_biases
+
+    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+    state_dict[norm2_key] = norm2_weights_or_biases
+
+    state_dict.pop(key)
+
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    state_dict.pop(key)
+
+
+def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    key_split = key.split(".")
+    layer_index = int(key_split[2])
+    replace_layer_index = 4 - 1 - layer_index
+
+    key_split[1] = "up_blocks"
+    key_split[2] = str(replace_layer_index)
+    new_key = ".".join(key_split)
+
+    state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+    "transformer.final_layernorm": "norm_final",
+    "transformer": "transformer_blocks",
+    "attention": "attn1",
+    "mlp": "ff.net",
+    "dense_h_to_4h": "0.proj",
+    "dense_4h_to_h": "2",
+    ".layers": "",
+    "dense": "to_out.0",
+    "input_layernorm": "norm1.norm",
+    "post_attn1_layernorm": "norm2.norm",
+    "time_embed.0": "time_embedding.linear_1",
+    "time_embed.2": "time_embedding.linear_2",
+    "mixins.patch_embed": "patch_embed",
+    "mixins.final_layer.norm_final": "norm_out.norm",
+    "mixins.final_layer.linear": "proj_out",
+    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "query_key_value": reassign_query_key_value_inplace,
+    "query_layernorm_list": reassign_query_key_layernorm_inplace,
+    "key_layernorm_list": reassign_query_key_layernorm_inplace,
+    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
+    "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
+}
+
+VAE_KEYS_RENAME_DICT = {
+    "block.": "resnets.",
+    "down.": "down_blocks.",
+    "downsample": "downsamplers.0",
+    "upsample": "upsamplers.0",
+    "nin_shortcut": "conv_shortcut",
+    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
+    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
+    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
+    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+    "loss": remove_keys_inplace,
+    "up.": replace_up_keys_inplace,
+}
+
+TOKENIZER_MAX_LENGTH = 226
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+    state_dict = saved_dict
+    if "model" in saved_dict.keys():
+        state_dict = state_dict["model"]
+    if "module" in saved_dict.keys():
+        state_dict = state_dict["module"]
+    if "state_dict" in saved_dict.keys():
+        state_dict = state_dict["state_dict"]
+    return state_dict
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+    state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
+):
+    PREFIX_KEY = "model.diffusion_model."
+
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+    transformer = CogVideoXTransformer3DModel(
+        in_channels=32 if i2v else 16,
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        use_learned_positional_embeddings=i2v,
+    ).to(dtype=dtype)
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[len(PREFIX_KEY):]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+    transformer.load_state_dict(original_state_dict, strict=True)
+    return transformer
+
+
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    vae.load_state_dict(original_state_dict, strict=True)
+    return vae
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+    )
+    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
+    parser.add_argument(
+        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+    )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the checkpoint is an image-to-video (I2V) model")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    transformer = None
+    vae = None
+
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+    if args.transformer_ckpt_path is not None:
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            args.i2v,
+            dtype,
+        )
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+    #text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+    #tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    #text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Apparently, the conversion does not work anymore without this :shrug:
+    #for param in text_encoder.parameters():
+    #    param.data = param.data.contiguous()
+
+    scheduler = CogVideoXDDIMScheduler.from_config(
+        {
+            "snr_shift_scale": args.snr_shift_scale,
+            "beta_end": 0.012,
+            "beta_schedule": "scaled_linear",
+            "beta_start": 0.00085,
+            "clip_sample": False,
+            "num_train_timesteps": 1000,
+            "prediction_type": "v_prediction",
+            "rescale_betas_zero_snr": True,
+            "set_alpha_to_one": True,
+            "timestep_spacing": "trailing",
+        }
+    )
+    if args.i2v:
+        pipeline_cls = CogVideoXImageToVideoPipeline
+    else:
+        pipeline_cls = CogVideoXPipeline
+
+    pipe = pipeline_cls(
+        tokenizer=None,
+        text_encoder=None,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
+    )
+
+    if args.fp16:
+        pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)
+
+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
+    # is either fp16/bf16 here).
+
+    # This is necessary for users with insufficient memory,
+    # such as those using Colab and notebooks, as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 09e6771..0e36cba 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -21,6 +21,8 @@
 import torch.nn.functional as F
 import numpy as np
 from einops import rearrange
+from functools import reduce
+from operator import mul
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils import logging
@@ -32,6 +34,7 @@
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from diffusers.loaders import PeftAdapterMixin
+from .embeddings import CogVideoX1_1PatchEmbed
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -61,6 +64,14 @@ def fft(tensor):
     return low_freq_fft, high_freq_fft
 
 
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -70,6 +81,16 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+    def rotary(self, t, rope_args):
+        def reshape_freq(freqs):
+            freqs = freqs[: rope_args["T"], : rope_args["H"], : rope_args["W"]].contiguous()
+            freqs = rearrange(freqs, "t h w d -> (t h w) d")
+            freqs = freqs.unsqueeze(0).unsqueeze(0)
+            return freqs
+        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
+        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
+
+        return t * freqs_cos + rotate_half(t) * freqs_sin
     @torch.compiler.disable()
     def __call__(
         self,
@@ -78,6 +99,7 @@
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        rope_args: Optional[dict] = None
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -107,13 +129,33 @@
         if attn.norm_k is not None:
             key = attn.norm_k(key)
+
+
         # Apply RoPE if needed
         if image_rotary_emb is not None:
+            self.freqs_cos = image_rotary_emb[0]
+            self.freqs_sin = image_rotary_emb[1]
+            print("rope args", rope_args) #{'T': 6, 'H': 30, 'W': 45, 'seq_length': 8775}
+            print("freqs_cos", self.freqs_cos.shape) #torch.Size([13, 30, 45, 64])
+            print("freqs_sin", self.freqs_sin.shape)
+
+
             from diffusers.models.embeddings import apply_rotary_emb
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            #query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            query = torch.cat(
+                (query[:, :, : text_seq_length],
+                self.rotary(query[:, :, text_seq_length:],
+                rope_args)),
+                dim=2)
+
             if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+                #key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+                key = torch.cat(
+                    (key[ :, :, : text_seq_length],
+                    self.rotary(key[:, :, text_seq_length:],
+                    rope_args)),
+                    dim=2)
 
         if SAGEATTN_IS_AVAILABLE:
             hidden_states = sageattn(query, key, value, is_causal=False)
@@ -303,6 +345,7 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
+        rope_args=None
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
         # norm & modulate
@@ -335,7 +378,7 @@ class CogVideoXBlock(nn.Module):
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb,
+            image_rotary_emb=image_rotary_emb, rope_args=rope_args
         )
         if fastercache_counter == fastercache_start_step:
             self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -458,12 +501,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
 
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(
+        self.patch_embed = CogVideoX1_1PatchEmbed(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            #bias=True,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -507,7 +550,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
 
@@ -635,6 +678,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        print("p", p)
 
         # 1. Time embedding
         timesteps = timestep
@@ -646,6 +691,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
 
+        # RoPE
+        seq_length = num_frames * height * width // reduce(mul, [p, p, p])
+        rope_T = num_frames // p
+        rope_H = height // p
+        rope_W = width // p
+        rope_args = {
+            "T": rope_T,
+            "H": rope_H,
+            "W": rope_W,
+            "seq_length": seq_length,
+        }
+
+
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -696,7 +753,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # Note: we use `-1` instead of `channels`:
         # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
         # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
-        p = self.config.patch_size
+
         output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
@@ -728,7 +785,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                     fastercache_counter = self.fastercache_counter,
-                    fastercache_device = self.fastercache_device
+                    fastercache_device = self.fastercache_device,
+                    rope_args=rope_args
                 )
 
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
diff --git a/embeddings.py b/embeddings.py
new file mode 100644
index 0000000..9747e91
--- /dev/null
+++ b/embeddings.py
@@ -0,0 +1,353 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from typing import Tuple, Union
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
+):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the context extrapolation. Defaults to 1.0.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concatenated with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    assert dim % 2 == 0
+
+    if isinstance(pos, int):
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
+    theta = theta * ntk_factor
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    if use_real and repeat_interleave_real:
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    elif use_real:
+        # stable audio
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        # lumina
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
+        return freqs_cis
+
+def get_3d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim: (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+    start, stop = crops_coords
+    grid_size_h, grid_size_w = grid_size
+    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        #freqs = freqs.view(
+        #    temporal_size * grid_size_h * grid_size_w, -1
+        #)  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
+
+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Args:
+        embed_dim (`int`):
+        spatial_size (`int` or `Tuple[int, int]`):
+        temporal_size (`int`):
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed(
+    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+    """
+    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
+    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+class CogVideoX1_1PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 81,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Adjust patch_size to handle three dimensions
+        self.patch_size = (patch_size, patch_size, patch_size)  # (depth, height, width)
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+        # Use Linear layer for projection
+        self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim)
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size[1]
+        post_patch_width = sample_width // self.patch_size[2]
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
+        joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        """
+        Args:
+            text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
+            image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+        first_frame = image_embeds[:, 0:1, :, :, :]
+        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1)  # (batch, 2, channels, height, width)
+        # Copy the first frames, for t_patch
+        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
+        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
+
+        rope_patch_t = num_frames // self.patch_size[0]
+        rope_patch_h = height // self.patch_size[1]
+        rope_patch_w = width // self.patch_size[2]
+
+        image_embeds = image_embeds.view(
+            batch,
+            rope_patch_t, self.patch_size[0],
+            rope_patch_h, self.patch_size[1],
+            rope_patch_w, self.patch_size[2],
+            channels
+        )
+        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
+        image_embeds = self.proj(image_embeds)
+        # Concatenate text and image embeddings
+        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
+
+        # Add positional embeddings if applicable
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
+        return embeds
\ No newline at end of file
diff --git a/model_loading.py b/model_loading.py
index c3591c0..91c6ed3 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -71,6 +71,7 @@ class DownloadAndLoadCogVideoModel:
                 "THUDM/CogVideoX-2b",
                 "THUDM/CogVideoX-5b",
                 "THUDM/CogVideoX-5b-I2V",
+                "kijai/CogVideoX-5b-1.5-T2V",
                 "bertjiazheng/KoolCogVideoX-5b",
                 "kijai/CogVideoX-Fun-2b",
                 "kijai/CogVideoX-Fun-5b",
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 9e38ea8..571498a 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -25,8 +25,9 @@
 from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
+#from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.loaders import CogVideoXLoraLoaderMixin
+from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
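Usage note (outside the diff, not part of the patch): the pipeline saved by convert_weight_sat2hf.py deliberately stores tokenizer=None and text_encoder=None, so the converted weights are typically loaded back component by component. A minimal sketch using diffusers' standard from_pretrained API, assuming the script was run with --output_path ./CogVideoX1.1-5B (the path and dtype below are placeholders):

    import torch
    from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel

    # Load the converted transformer from the saved pipeline directory.
    # bf16 matches the 5B family; fp16 would be used for the 2B model.
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "./CogVideoX1.1-5B", subfolder="transformer", torch_dtype=torch.bfloat16
    )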