mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
605 lines
27 KiB
Python
605 lines
27 KiB
Python
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import os
|
|
import json
|
|
import torch
|
|
import glob
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
from diffusers.utils import is_torch_version, logging
|
|
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
from diffusers.models.attention import Attention, FeedForward
|
|
from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
from diffusers.models.modeling_utils import ModelMixin
|
|
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
|
|
|
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
class CogVideoXPatchEmbed(nn.Module):
    r"""
    Joint token embedding for text and video latents.

    Every frame is cut into non-overlapping `patch_size x patch_size` patches by a strided 2D convolution, the text
    embeddings are linearly projected to the same width, and both token sequences are concatenated with the text
    tokens first.

    Parameters:
        patch_size (`int`, defaults to `2`):
            Side length (and stride) of the square patches.
        in_channels (`int`, defaults to `16`):
            Number of channels of each input frame.
        embed_dim (`int`, defaults to `1920`):
            Width of the produced tokens.
        text_embed_dim (`int`, defaults to `4096`):
            Width of the incoming text embeddings.
        bias (`bool`, defaults to `True`):
            Whether the patch convolution uses a bias term.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        # Sub-modules are created in the same order (proj, then text_proj) so that parameter ordering and seeded
        # initialisation are unaffected.
        patch_shape = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_shape, stride=patch_size, bias=bias)
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            `torch.Tensor`: contiguous tensor of shape
            (batch_size, seq_length + num_frames x patched_height x patched_width, embed_dim).
        """
        text_tokens = self.text_proj(text_embeds)

        batch, num_frames, channels, height, width = image_embeds.shape

        # Fold the frame axis into the batch axis so the 2D convolution sees ordinary images.
        frames = image_embeds.reshape(-1, channels, height, width)
        patches = self.proj(frames)

        # Restore the frame axis: [batch, num_frames, channels, height, width] (post-patching sizes).
        patches = patches.view(batch, num_frames, *patches.shape[1:])

        # [batch, num_frames, height x width, channels] -> [batch, num_frames x height x width, channels]
        patch_tokens = patches.flatten(3).transpose(2, 3).flatten(1, 2)

        # [batch, seq_length + num_frames x height x width, channels]
        joint = torch.cat([text_tokens, patch_tokens], dim=1)
        return joint.contiguous()
|
|
|
|
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Run one gated attention + feed-forward step over the video and text token streams.

        Args:
            hidden_states (`torch.Tensor`):
                Video tokens; the sequence axis is dimension 1.
            encoder_hidden_states (`torch.Tensor`):
                Text tokens; the sequence axis is dimension 1.
            temb (`torch.Tensor`):
                Timestep embedding passed to both modulated layer norms.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embedding forwarded unchanged to the attention layer.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the updated `(hidden_states, encoder_hidden_states)`.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward: both streams share one pass, text tokens first, and are split again afterwards
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
|
|
|
|
|
|
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to `0`):
            Frequency shift forwarded to the `Timesteps` projection.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
        use_rotary_positional_embeddings (`bool`, defaults to `False`):
            When `False`, `forward` adds the (resized) sin-cos positional buffer to the tokens and applies the final
            layer norm to the video tokens only. When `True`, no positional buffer is added and the final layer norm
            is applied to the concatenated text + video tokens.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        # Kept on the instance because `forward` reshapes the positional buffer back to this grid.
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. 3D positional embeddings
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            inner_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        # The text positions stay zero; only the video positions receive the sin-cos embedding.
        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 3. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 4. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 5. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Turn gradient checkpointing on or off for the whole model (`module` is ignored)."""
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

                Note that a dict argument is consumed: each entry is popped as it is assigned.

        Raises:
            ValueError: if a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Does nothing when `fuse_qkv_projections` has not been called on this model.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists after `fuse_qkv_projections` ran, so look it up defensively.
        original_processors = getattr(self, "original_attn_processors", None)
        if original_processors is not None:
            # `set_attn_processor` pops entries from a dict argument; hand it a copy so the saved mapping
            # survives and this method can safely be called more than once.
            self.set_attn_processor(dict(original_processors))

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        inpaint_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        """
        Run the transformer on a batch of video latents conditioned on text embeddings and a timestep.

        Args:
            hidden_states (`torch.Tensor`):
                Video latents of shape (batch_size, num_frames, channels, height, width).
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings; dimension 1 is the text sequence length.
            timestep (`int`, `float` or `torch.LongTensor`):
                Diffusion timestep(s) passed to the timestep projection.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning forwarded to the timestep embedding.
            inpaint_latents (`torch.Tensor`, *optional*):
                Additional latents concatenated to `hidden_states` along dimension 2 before patch embedding.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embedding forwarded to every transformer block.
            return_dict (`bool`, defaults to `True`):
                Whether to wrap the result in a `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is `True`, otherwise a 1-tuple holding the output tensor.
        """
        # `channels` is read before any inpaint concatenation; it is the channel count used when unpatchifying.
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

        # 3. Position embedding
        text_seq_length = encoder_hidden_states.shape[1]
        if not self.config.use_rotary_positional_embeddings:
            seq_length = height * width * num_frames // (self.config.patch_size**2)
            pos_embeds = self.pos_embedding
            emb_size = hidden_states.size()[-1]
            # NOTE: the buffer holds `max_text_seq_length` text slots followed by `num_patches` video slots, so
            # this view only succeeds when `text_seq_length` equals the configured `max_text_seq_length`.
            pos_embeds_without_text = pos_embeds[:, text_seq_length:].view(
                1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size
            )
            # Channels-first layout for interpolation, then resize the spatial grid to the actual input size.
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
            pos_embeds_without_text = F.interpolate(
                pos_embeds_without_text,
                size=[
                    self.post_time_compression_frames,
                    height // self.config.patch_size,
                    width // self.config.patch_size,
                ],
                mode='trilinear',
                align_corners=False,
            )
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim=1)
            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
            hidden_states = hidden_states + pos_embeds
            hidden_states = self.embedding_dropout(hidden_states)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Transformer blocks
        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:
            # Built once for all blocks instead of once per block.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for block in self.transformer_blocks:
            if use_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 5. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 6. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs=None):
        """
        Build the model from `config.json` and load whatever checkpoint weights match it.

        Weights are looked up in this order: the default `.bin` checkpoint, its `.safetensors` sibling, then every
        `*.safetensors` file in the directory (merged). The patch-embedding convolution weight is adapted when its
        shape differs from the checkpoint, and any other key that is unknown or has a different shape is skipped.

        Args:
            pretrained_model_path (`str`):
                Directory holding `config.json` and the weight file(s).
            subfolder (`str`, *optional*):
                Sub-directory of `pretrained_model_path` to load from.
            transformer_additional_kwargs (`dict`, *optional*):
                Extra keyword arguments passed to `from_config`, overriding or extending the stored config.

        Returns:
            The instantiated model with the matching weights loaded.

        Raises:
            RuntimeError: if `config.json` is missing or no weights could be found.
        """
        # Avoid a shared mutable default argument.
        if transformer_additional_kwargs is None:
            transformer_additional_kwargs = {}

        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        from diffusers.utils import WEIGHTS_NAME
        model = cls.from_config(config, **transformer_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        # Swap only the file extension; a plain str.replace would also rewrite ".bin" inside directory names.
        model_file_safetensors = os.path.splitext(model_file)[0] + ".safetensors"
        if os.path.exists(model_file):
            # NOTE(security): torch.load unpickles the file and can execute arbitrary code; only load
            # checkpoints from trusted sources (or prefer the safetensors variants handled below).
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file
            state_dict = {}
            for shard_file in glob.glob(os.path.join(pretrained_model_path, "*.safetensors")):
                state_dict.update(load_file(shard_file))

        if not state_dict:
            raise RuntimeError(f"no model weights found in {pretrained_model_path}")

        # Build the model's state dict once. Its tensors share storage with the model parameters, so the
        # in-place writes below still update the model itself.
        model_state = model.state_dict()

        patch_key = 'patch_embed.proj.weight'
        model_weight = model_state[patch_key]
        ckpt_weight = state_dict[patch_key]
        if model_weight.size() != ckpt_weight.size():
            new_shape = model_weight.size()
            if len(new_shape) == 5:
                # 3D kernel: replicate the 2D kernel along the new axis and keep only its last slice non-zero.
                inflated = ckpt_weight.unsqueeze(2).expand(new_shape).clone()
                inflated[:, :, :-1] = 0
                state_dict[patch_key] = inflated
            else:
                ckpt_in_channels = ckpt_weight.size()[1]
                if model_weight.size()[1] > ckpt_in_channels:
                    # Model has extra input channels: copy the known ones and zero the rest.
                    model_weight[:, :ckpt_in_channels, :, :] = ckpt_weight
                    model_weight[:, ckpt_in_channels:, :, :] = 0
                else:
                    # Model has fewer input channels: keep the leading ones from the checkpoint.
                    model_weight[:, :, :, :] = ckpt_weight[:, :model_weight.size()[1], :, :]
                state_dict[patch_key] = model_weight

        tmp_state_dict = {}
        for key in state_dict:
            if key in model_state and model_state[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "Size don't match, skip")
        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
        print(f"### Mamba Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        return model