mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 21:04:23 +08:00

doesn't work yet

This commit is contained in:
parent 1cc6e1f070
commit 2074ba578e

303  convert_weight_sat2hf.py  Normal file
@@ -0,0 +1,303 @@
"""
The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.

This script supports the conversion of the following models:
- CogVideoX-2B
- CogVideoX-5B, CogVideoX-5B-I2V
- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V

Original Script:
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
"""
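
# Illustrative invocation (not part of the commit; paths are placeholders and the 5B values
# follow the comments in get_args() below):
#
#   python convert_weight_sat2hf.py \
#       --transformer_ckpt_path /path/to/sat/transformer.pt \
#       --vae_ckpt_path /path/to/sat/vae.pt \
#       --output_path ./cogvideox-5b-hf \
#       --num_layers 42 --num_attention_heads 48 --use_rotary_positional_embeddings \
#       --scaling_factor 0.7 --snr_shift_scale 1.0 --bf16
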
import argparse
from typing import Any, Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDDIMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXPipeline,
    #CogVideoXTransformer3DModel,
)
from custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    to_q_key = key.replace("query_key_value", "to_q")
    to_k_key = key.replace("query_key_value", "to_k")
    to_v_key = key.replace("query_key_value", "to_v")
    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
    state_dict[to_q_key] = to_q
    state_dict[to_k_key] = to_k
    state_dict[to_v_key] = to_v
    state_dict.pop(key)
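
# For illustration (not in the commit): SAT stores the attention projections fused as a single
# "query_key_value" tensor, so a weight of shape [3 * hidden, hidden] (hidden = 1920 for
# CogVideoX-2B, assuming 30 heads of dim 64) is chunked along dim 0 into three [hidden, hidden]
# tensors for to_q / to_k / to_v; the matching [3 * hidden] bias is split the same way.
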
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
    elif "key" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"

    state_dict[new_key] = state_dict.pop(key)


def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    layer_id, _, weight_or_bias = key.split(".")[-3:]

    weights_or_biases = state_dict[key].chunk(12, dim=0)
    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])

    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
    state_dict[norm1_key] = norm1_weights_or_biases

    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
    state_dict[norm2_key] = norm2_weights_or_biases

    state_dict.pop(key)
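
# Layout note (inferred from the chunking above, not stated in the commit): SAT stores the adaLN
# modulation of a block as 12 stacked chunks; chunks 0-2 and 6-8 are regrouped into norm1.linear
# and chunks 3-5 and 9-11 into norm2.linear, presumably because each CogVideoXLayerNormZero
# linear in diffusers expects 6 chunks (shift/scale/gate for the video and the text stream).
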
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    key_split = key.split(".")
    layer_index = int(key_split[2])
    replace_layer_index = 4 - 1 - layer_index

    key_split[1] = "up_blocks"
    key_split[2] = str(replace_layer_index)
    new_key = ".".join(key_split)

    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
}
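
# Worked example (illustration only): a SAT key such as
#   "model.diffusion_model.transformer.layers.0.attention.query_key_value.weight"
# first has the "model.diffusion_model." prefix stripped in convert_transformer(), then the
# substring renames above turn it into
#   "transformer_blocks.0.attn1.query_key_value.weight",
# which the "query_key_value" handler in TRANSFORMER_SPECIAL_KEYS_REMAP below finally splits
# into separate to_q / to_k / to_v entries.
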

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}

VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}

VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}

TOKENIZER_MAX_LENGTH = 226


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        use_learned_positional_embeddings=i2v,
    ).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)
    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
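
# Sketch only (not part of the commit): calling convert_transformer() directly with the 5B
# values noted in get_args() below; the checkpoint path is a placeholder.
#
#   transformer = convert_transformer(
#       "/path/to/sat/transformer.pt",
#       num_layers=42,
#       num_attention_heads=48,
#       use_rotary_positional_embeddings=True,
#       i2v=False,
#       dtype=torch.bfloat16,
#   )
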
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
    parser.add_argument("--i2v", action="store_true", default=False, help="Whether this is an image-to-video (I2V) checkpoint")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    transformer = None
    vae = None

    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            args.use_rotary_positional_embeddings,
            args.i2v,
            dtype,
        )
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

    #text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
    #tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    #text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    #for param in text_encoder.parameters():
    #    param.data = param.data.contiguous()

    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": args.snr_shift_scale,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )
    if args.i2v:
        pipeline_cls = CogVideoXImageToVideoPipeline
    else:
        pipeline_cls = CogVideoXPipeline

    pipe = pipeline_cls(
        tokenizer=None,
        text_encoder=None,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    if args.fp16:
        pipe = pipe.to(dtype=torch.float16)
    if args.bf16:
        pipe = pipe.to(dtype=torch.bfloat16)

    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default
    # (which is either fp16/bf16 here).

    # Sharding is necessary for users with insufficient memory, such as those using Colab and notebooks,
    # as it can save some memory used for model loading.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
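
# Note (illustrative, not part of the commit): the saved pipeline omits the tokenizer and text
# encoder, so reloading it afterwards would presumably look something like
#
#   pipe = CogVideoXPipeline.from_pretrained("./cogvideox-5b-hf", torch_dtype=torch.bfloat16)
#
# with the text encoder and tokenizer supplied separately.
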
custom_cogvideox_transformer_3d.py
@@ -21,6 +21,8 @@ import torch.nn.functional as F

import numpy as np
from einops import rearrange
from functools import reduce
from operator import mul

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
@@ -32,6 +34,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from diffusers.loaders import PeftAdapterMixin
from .embeddings import CogVideoX1_1PatchEmbed


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -61,6 +64,14 @@ def fft(tensor):

    return low_freq_fft, high_freq_fft

def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
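# Quick check of rotate_half (illustration, not in the commit): adjacent channel pairs are
# rotated, so [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]; combined with the cos/sin terms below this
# is the standard interleaved form of rotary position embedding (RoPE).
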

class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -70,6 +81,16 @@ class CogVideoXAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def rotary(self, t, rope_args):
        def reshape_freq(freqs):
            freqs = freqs[: rope_args["T"], : rope_args["H"], : rope_args["W"]].contiguous()
            freqs = rearrange(freqs, "t h w d -> (t h w) d")
            freqs = freqs.unsqueeze(0).unsqueeze(0)
            return freqs

        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)

        return t * freqs_cos + rotate_half(t) * freqs_sin
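
    # For a single channel pair (a, b) with angle theta the line above computes
    # (a*cos - b*sin, b*cos + a*sin), i.e. a 2D rotation per pair; reshape_freq slices the
    # precomputed freqs_cos/freqs_sin tables to the current T/H/W token grid.
    # (Comment added for illustration, not part of the commit.)
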
    @torch.compiler.disable()
    def __call__(
        self,
@@ -78,6 +99,7 @@ class CogVideoXAttnProcessor2_0:
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        rope_args: Optional[dict] = None
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

@@ -107,13 +129,33 @@ class CogVideoXAttnProcessor2_0:
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            self.freqs_cos = image_rotary_emb[0]
            self.freqs_sin = image_rotary_emb[1]
            print("rope args", rope_args)  #{'T': 6, 'H': 30, 'W': 45, 'seq_length': 8775}
            print("freqs_cos", self.freqs_cos.shape)  #torch.Size([13, 30, 45, 64])
            print("freqs_sin", self.freqs_sin.shape)

            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            #query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            query = torch.cat(
                (query[:, :, : text_seq_length],
                 self.rotary(query[:, :, text_seq_length:],
                             rope_args)),
                dim=2)

            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
                #key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
                key = torch.cat(
                    (key[:, :, : text_seq_length],
                     self.rotary(key[:, :, text_seq_length:],
                                 rope_args)),
                    dim=2)

        if SAGEATTN_IS_AVAILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
@@ -303,6 +345,7 @@ class CogVideoXBlock(nn.Module):
        fastercache_counter=0,
        fastercache_start_step=15,
        fastercache_device="cuda:0",
        rope_args=None
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)
        # norm & modulate
@@ -335,7 +378,7 @@ class CogVideoXBlock(nn.Module):
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            image_rotary_emb=image_rotary_emb, rope_args=rope_args
        )
        if fastercache_counter == fastercache_start_step:
            self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -458,12 +501,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
        self.patch_embed = CogVideoX1_1PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            #bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
@@ -507,7 +550,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

@@ -635,6 +678,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape
        p = self.config.patch_size
        print("p", p)

        # 1. Time embedding
        timesteps = timestep
@@ -646,6 +691,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # RoPE
        seq_length = num_frames * height * width // reduce(mul, [p, p, p])
        rope_T = num_frames // p
        rope_H = height // p
        rope_W = width // p
        rope_args = {
            "T": rope_T,
            "H": rope_H,
            "W": rope_W,
            "seq_length": seq_length,
        }

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)
@@ -696,7 +753,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        # Note: we use `-1` instead of `channels`:
        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
        p = self.config.patch_size

        output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

@@ -728,7 +785,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                fastercache_counter = self.fastercache_counter,
                fastercache_device = self.fastercache_device
                fastercache_device = self.fastercache_device,
                rope_args=rope_args
            )

            if (controlnet_states is not None) and (i < len(controlnet_states)):
353  embeddings.py  Normal file
@@ -0,0 +1,353 @@
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Union


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
        return freqs_cis
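
# Shape check (illustration, not in the commit): with dim=64 and pos=16, freqs has shape [16, 32];
# with use_real=True and repeat_interleave_real=True each frequency is interleaved with itself,
# so freqs_cos and freqs_sin both come out as [16, 64], matching the per-head channel count.
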
def get_3d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        #freqs = freqs.view(
        #    temporal_size * grid_size_h * grid_size_w, -1
        #)  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
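
# Shape check (illustration, not in the commit): for a head dimension of embed_dim=64 the split is
# dim_t=16, dim_h=24, dim_w=24 (16 + 24 + 24 = 64), and the returned cos/sin tensors have shape
# (temporal_size, grid_size_h, grid_size_w, 64), which is what reshape_freq() in the attention
# processor later slices per sample.
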
def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Args:
        embed_dim (`int`):
        spatial_size (`int` or `Tuple[int, int]`):
        temporal_size (`int`):
        spatial_interpolation_scale (`float`, defaults to 1.0):
        temporal_interpolation_scale (`float`, defaults to 1.0):
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    embed_dim_spatial = 3 * embed_dim // 4
    embed_dim_temporal = embed_dim // 4

    # 1. Spatial
    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # 2. Temporal
    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # 3. Concat
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
    return pos_embed


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class CogVideoX1_1PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 81,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        # Adjust patch_size to handle three dimensions
        self.patch_size = (patch_size, patch_size, patch_size)  # (depth, height, width)
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        # Use Linear layer for projection
        self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim)
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        post_patch_height = sample_height // self.patch_size[1]
        post_patch_width = sample_width // self.patch_size[2]
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
        joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)

        return joint_pos_embedding
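
    # Layout note (illustration, not in the commit): joint_pos_embedding has shape
    # [1, max_text_seq_length + num_patches, embed_dim]; the first max_text_seq_length positions
    # stay zero for the text tokens and the 3D sincos table fills the video-token positions.
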
    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        """
        Args:
            text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
            image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)
        first_frame = image_embeds[:, 0:1, :, :, :]
        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1)  # (batch, 2, channels, height, width)
        # Duplicate the first frame for the temporal patching (t_patch)
        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
        batch, num_frames, channels, height, width = image_embeds.shape
        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)

        rope_patch_t = num_frames // self.patch_size[0]
        rope_patch_h = height // self.patch_size[1]
        rope_patch_w = width // self.patch_size[2]

        image_embeds = image_embeds.view(
            batch,
            rope_patch_t, self.patch_size[0],
            rope_patch_h, self.patch_size[1],
            rope_patch_w, self.patch_size[2],
            channels
        )
        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
        image_embeds = self.proj(image_embeds)
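        # Shape walkthrough (illustration, not in the commit): with patch size p, the
        # (batch, num_frames, channels, height, width) latents become
        # num_frames/p * height/p * width/p tokens per sample, each of length channels * p**3,
        # which self.proj (nn.Linear(in_channels * patch_size**3, embed_dim)) maps to embed_dim.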
        # Concatenate text and image embeddings
        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()

        # Add positional embeddings if applicable
        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
            else:
                pos_embedding = self.pos_embedding

            embeds = embeds + pos_embedding

        return embeds
@@ -71,6 +71,7 @@ class DownloadAndLoadCogVideoModel:
            "THUDM/CogVideoX-2b",
            "THUDM/CogVideoX-5b",
            "THUDM/CogVideoX-5b-I2V",
            "kijai/CogVideoX-5b-1.5-T2V",
            "bertjiazheng/KoolCogVideoX-5b",
            "kijai/CogVideoX-Fun-2b",
            "kijai/CogVideoX-Fun-5b",
@@ -25,8 +25,9 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
#from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.loaders import CogVideoXLoraLoaderMixin
from .embeddings import get_3d_rotary_pos_embed

from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel