mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 21:34:22 +08:00

code cleanup

codebase getting too bloated: drop PAB support in favor of FasterCache, drop temporal tiling in favor of FreeNoise

This commit is contained in:
parent e8a289112f
commit 0bd3da569e
@@ -1,741 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import os
import json
import torch
import glob
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
#from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero

from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from ..videosys.modules.embeddings import apply_rotary_emb
from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAVILABLE = True
    logger.info("Using sageattn")
except:
    logger.info("sageattn not found, using sdpa")
    SAGEATTN_IS_AVAVILABLE = False

class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # if attn.parallel_manager.sp_size > 1:
        #     assert (
        #         attn.heads % attn.parallel_manager.sp_size == 0
        #     ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
        #     attn_heads = attn.heads // attn.parallel_manager.sp_size
        #     query, key, value = map(
        #         lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
        #         [query, key, value],
        #     )

        attn_heads = attn.heads

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn_heads

        query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            emb_len = image_rotary_emb[0].shape[0]
            query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
                query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
            )
            if not attn.is_cross_attention:
                key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
                    key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
                )

        if SAGEATTN_IS_AVAVILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)

        #if attn.parallel_manager.sp_size > 1:
        #    hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states

class FusedCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if SAGEATTN_IS_AVAVILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states

class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)

        batch, num_frames, channels, height, width = image_embeds.shape
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
        return embeds

@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        block_idx: int = 0,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # parallel
        #self.attn1.parallel_manager = None

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # pab
        self.attn_count = 0
        self.last_attn = None
        self.block_idx = block_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep=None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        if enable_pab():
            broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
        if enable_pab() and broadcast_attn:
            attn_hidden_states, attn_encoder_hidden_states = self.last_attn
        else:
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
            if enable_pab():
                self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states

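# --- Illustrative aside (not part of this commit's diff) --------------------
# The forward() above asks videosys' pab_mgr whether to reuse the cached
# attention output for the current timestep. Below is a minimal, hypothetical
# sketch of that bookkeeping; the window and range values mirror the PAB node
# defaults shown later in this commit, but the real if_broadcast_spatial in
# videosys differs and this stand-in is only meant to convey the idea.
def if_broadcast_spatial_sketch(timestep: int, count: int, broadcast_range: int = 2,
                                threshold: tuple = (100, 850)):
    """Return (reuse_cached_attention, updated_count) for one transformer block."""
    count += 1
    within_window = threshold[0] <= timestep <= threshold[1]
    # Recompute attention every `broadcast_range` steps inside the window,
    # otherwise broadcast (reuse) the block's cached self.last_attn output.
    reuse = within_window and (count % broadcast_range != 0)
    return reuse, count
# -----------------------------------------------------------------------------
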
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        add_noise_in_inpaint_model: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. 3D positional embeddings
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            inner_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 3. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 4. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 5. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

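# --- Illustrative aside (not part of this commit's diff) --------------------
# What "fusing QKV projections" means in practice: three separate Linear
# projections are replaced by a single matrix multiply whose weight is their
# concatenation, and the output is split back into q, k, v the way
# FusedCogVideoXAttnProcessor2_0 does above. A minimal, self-contained sketch
# (not the diffusers implementation; it assumes the three projections share
# the same shapes and bias settings):
import torch
from torch import nn

def fuse_qkv_sketch(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear) -> nn.Linear:
    fused = nn.Linear(to_q.in_features, to_q.out_features * 3, bias=to_q.bias is not None)
    with torch.no_grad():
        # nn.Linear weights are (out_features, in_features), so stack along dim 0.
        fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
        if to_q.bias is not None:
            fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))
    return fused

# Usage: qkv = fused(x); q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
# -----------------------------------------------------------------------------
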
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        inpaint_latents: Optional[torch.Tensor] = None,
        control_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
        if control_latents is not None:
            hidden_states = torch.concat([hidden_states, control_latents], 2)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

        # 3. Position embedding
        text_seq_length = encoder_hidden_states.shape[1]
        if not self.config.use_rotary_positional_embeddings:
            seq_length = height * width * num_frames // (self.config.patch_size**2)
            # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
            pos_embeds = self.pos_embedding
            emb_size = hidden_states.size()[-1]
            pos_embeds_without_text = pos_embeds[:, text_seq_length:].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
            pos_embeds_without_text = F.interpolate(pos_embeds_without_text, size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size], mode='trilinear', align_corners=False)
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim=1)
            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
            hidden_states = hidden_states + pos_embeds
            hidden_states = self.embedding_dropout(hidden_states)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    timestep=timestep,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 5. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 6. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        from diffusers.utils import WEIGHTS_NAME
        model = cls.from_config(config, **transformer_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file, safe_open
            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
            state_dict = {}
            for model_file_safetensors in model_files_safetensors:
                _state_dict = load_file(model_file_safetensors)
                for key in _state_dict:
                    state_dict[key] = _state_dict[key]

        if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
            new_shape = model.state_dict()['patch_embed.proj.weight'].size()
            if len(new_shape) == 5:
                state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
                state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
            else:
                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                else:
                    model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']

        tmp_state_dict = {}
        for key in state_dict:
            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "Size don't match, skip")
        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
        print(f"### Mamba Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        return model
@@ -33,10 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange

from ..videosys.core.pipeline import VideoSysPipeline
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from ..videosys.core.pab_mgr import set_pab_manager


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -158,7 +154,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
    videos: torch.Tensor


class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

@@ -188,7 +184,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
        pab_config = None
    ):
        super().__init__()

@@ -210,9 +205,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

        if pab_config is not None:
            set_pab_manager(pab_config)

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
        latents=None, freenoise=True, context_size=None, context_overlap=None
@@ -348,16 +340,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def _gaussian_weights(self, t_tile_length, t_batch_size):
        from numpy import pi, exp, sqrt

        var = 0.01
        midpoint = (t_tile_length - 1) / 2  # -1 because index goes from 0 to latent_width - 1
        t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
        weights = torch.tensor(t_probs)
        weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size, 1, 1, 1)
        return weights

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
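# --- Illustrative aside (not part of this commit's diff) --------------------
# The _gaussian_weights() removed above supports the temporal tiling path this
# commit drops: each latent frame window is denoised separately and the
# overlapping windows are merged as a Gaussian-weighted average. A minimal,
# self-contained sketch of that blending step, with hypothetical tensor names
# and sizes (the pipeline code below does the same thing with latents_all and
# contributors):
import torch

def blend_tiles_sketch(tiles, offsets, weights, total_frames):
    # tiles: list of [B, T_tile, C, H, W] latents; weights: [1, T_tile, 1, 1, 1]
    out = torch.zeros(tiles[0].shape[0], total_frames, *tiles[0].shape[2:],
                      dtype=tiles[0].dtype, device=tiles[0].device)
    contributors = torch.zeros_like(out)
    for tile, ofs in zip(tiles, offsets):
        t_len = tile.shape[1]
        out[:, ofs:ofs + t_len] += tile * weights
        contributors[:, ofs:ofs + t_len] += weights
    return out / contributors  # frames covered by several tiles are averaged
# -----------------------------------------------------------------------------
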
@@ -697,24 +679,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.5. Temporal tiling prep
        if context_schedule is not None and context_schedule == "temporal_tiling":
            t_tile_length = context_frames
            t_tile_overlap = context_overlap
            t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
            use_temporal_tiling = True
            print("Temporal tiling enabled")
        elif context_schedule is not None:
        if context_schedule is not None:
            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
            use_temporal_tiling = False
            use_context_schedule = True
            from .context import get_context_scheduler
            context = get_context_scheduler(context_schedule)

        else:
            use_temporal_tiling = False
            use_context_schedule = False
            print("Temporal tiling and context schedule disabled")
            print(" context schedule disabled")
        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@@ -735,88 +708,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                # =====================================================
                grid_ts = 0
                cur_t = 0
                while cur_t < latents.shape[1]:
                    cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
                    grid_ts += 1

                all_t = latents.shape[1]
                latents_all_list = []
                # =====================================================

                image_rotary_emb = (
                    self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
                    if self.transformer.config.use_rotary_positional_embeddings
                    else None
                )

                for t_i in range(grid_ts):
                    if t_i < grid_ts - 1:
                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                    if t_i == grid_ts - 1:
                        ofs_t = all_t - t_tile_length

                    input_start_t = ofs_t
                    input_end_t = ofs_t + t_tile_length

                    latents_tile = latents[:, input_start_t:input_end_t, :, :, :]
                    control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]

                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)

                    #t_input = t[None].to(device)
                    t_input = t.expand(latent_model_input_tile.shape[0])  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

                    # predict noise model_output
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input_tile,
                        encoder_hidden_states=prompt_embeds,
                        timestep=t_input,
                        image_rotary_emb=image_rotary_emb,
                        return_dict=False,
                        control_latents=control_latents_tile,
                    )[0]
                    noise_pred = noise_pred.float()

                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
                    latents_all_list.append(latents_tile)

                # ==========================================
                latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                # Add each tile contribution to overall latents
                for t_i in range(grid_ts):
                    if t_i < grid_ts - 1:
                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                    if t_i == grid_ts - 1:
                        ofs_t = all_t - t_tile_length

                    input_start_t = ofs_t
                    input_end_t = ofs_t + t_tile_length

                    latents_all[:, input_start_t:input_end_t, :, :, :] += latents_all_list[t_i] * t_tile_weights
                    contributors[:, input_start_t:input_end_t, :, :, :] += t_tile_weights

                latents_all /= contributors

                latents = latents_all

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                pbar.update(1)
            # ==========================================
            elif use_context_schedule:
            if use_context_schedule:

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -33,11 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange

from ..videosys.core.pipeline import VideoSysPipeline
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from ..videosys.core.pab_mgr import set_pab_manager


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -206,7 +201,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
    videos: torch.Tensor


class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

@@ -236,7 +231,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
        pab_config = None
    ):
        super().__init__()

@@ -258,9 +252,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

        if pab_config is not None:
            set_pab_manager(pab_config)

    def prepare_latents(
        self,
        batch_size,
@@ -433,16 +424,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def _gaussian_weights(self, t_tile_length, t_batch_size):
        from numpy import pi, exp, sqrt

        var = 0.01
        midpoint = (t_tile_length - 1) / 2  # -1 because index goes from 0 to latent_width - 1
        t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
        weights = torch.tensor(t_probs)
        weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size, 1, 1, 1)
        return weights

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
@@ -866,22 +847,14 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        if context_schedule is not None and context_schedule == "temporal_tiling":
            t_tile_length = context_frames
            t_tile_overlap = context_overlap
            t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
            use_temporal_tiling = True
            print("Temporal tiling enabled")
        elif context_schedule is not None:
        if context_schedule is not None:
            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
            use_temporal_tiling = False
            use_context_schedule = True
            from .context import get_context_scheduler
            context = get_context_scheduler(context_schedule)
        else:
            use_temporal_tiling = False
            use_context_schedule = False
            print("Temporal tiling and context schedule disabled")
            print("context schedule disabled")
        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@@ -915,87 +888,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
            if self.interrupt:
                continue

            if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                # =====================================================
                grid_ts = 0
                cur_t = 0
                while cur_t < latents.shape[1]:
                    cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
                    grid_ts += 1

                all_t = latents.shape[1]
                latents_all_list = []
                # =====================================================

                image_rotary_emb = (
                    self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
                    if self.transformer.config.use_rotary_positional_embeddings
                    else None
                )

                for t_i in range(grid_ts):
                    if t_i < grid_ts - 1:
                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                    if t_i == grid_ts - 1:
                        ofs_t = all_t - t_tile_length

                    input_start_t = ofs_t
                    input_end_t = ofs_t + t_tile_length

                    latents_tile = latents[:, input_start_t:input_end_t, :, :, :]
                    inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]

                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)

                    #t_input = t[None].to(device)
                    t_input = t.expand(latent_model_input_tile.shape[0])  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

                    # predict noise model_output
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input_tile,
                        encoder_hidden_states=prompt_embeds,
                        timestep=t_input,
                        image_rotary_emb=image_rotary_emb,
                        return_dict=False,
                        inpaint_latents=inpaint_latents_tile,
                    )[0]
                    noise_pred = noise_pred.float()

                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
                    latents_all_list.append(latents_tile)

                # ==========================================
                latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                # Add each tile contribution to overall latents
                for t_i in range(grid_ts):
                    if t_i < grid_ts - 1:
                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                    if t_i == grid_ts - 1:
                        ofs_t = all_t - t_tile_length

                    input_start_t = ofs_t
                    input_end_t = ofs_t + t_tile_length

                    latents_all[:, input_start_t:input_end_t, :, :, :] += latents_all_list[t_i] * t_tile_weights
                    contributors[:, input_start_t:input_end_t, :, :, :] += t_tile_weights

                latents_all /= contributors

                latents = latents_all

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                pbar.update(1)
            # ==========================================
            elif use_context_schedule:
            if use_context_schedule:
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -1133,18 +1026,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
            else:
                pbar.update(1)

        # if output_type == "numpy":
        #     video = self.decode_latents(latents)
        # elif not output_type == "latent":
        #     video = self.decode_latents(latents)
        #     video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        # else:
        #     video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        # if not return_dict:
        #     video = torch.from_numpy(video)

        return latents
@@ -12,15 +12,12 @@ from .pipeline_cogvideox import CogVideoXPipeline
from contextlib import nullcontext

from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun

from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control

from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB

from .utils import check_diffusers_version, remove_specific_blocks, log
from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file

script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -95,7 +92,6 @@ class DownloadAndLoadCogVideoModel:
                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
                "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
                "pab_config": ("PAB_CONFIG", {"default": None}),
                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                "lora": ("COGLORA", {"default": None}),
                "compile_args":("COMPILEARGS", ),
@@ -111,7 +107,7 @@ class DownloadAndLoadCogVideoModel:
    DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"

    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
                  enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
                  enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
                  attention_mode="sdpa", load_device="main_device"):

        if precision == "fp16" and "1.5" in model:
@@ -188,15 +184,9 @@ class DownloadAndLoadCogVideoModel:

        # transformer
        if "Fun" in model:
            if pab_config is not None:
                transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder=subfolder)
            else:
                transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
            transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
        else:
            if pab_config is not None:
                transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder=subfolder)
            else:
                transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)

        transformer = transformer.to(dtype).to(transformer_load_device)

@@ -213,12 +203,12 @@ class DownloadAndLoadCogVideoModel:
        if "Fun" in model:
            vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
            if "Pose" in model:
                pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
                pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
            else:
                pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
                pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
        else:
            vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
            pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
            pipe = CogVideoXPipeline(vae, transformer, scheduler)
        if "cogvideox-2b-img2vid" in model:
            pipe.input_with_padding = False

@@ -296,7 +286,7 @@ class DownloadAndLoadCogVideoModel:
                backend="nexfort",
                options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
                ignores=["vae"],
                fuse_qkv_projections=True if pab_config is None else False,
                fuse_qkv_projections= False,
            )

        pipeline = {
@@ -334,7 +324,6 @@ class DownloadAndLoadCogVideoGGUFModel:
                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
            },
            "optional": {
                "pab_config": ("PAB_CONFIG", {"default": None}),
                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                #"lora": ("COGLORA", {"default": None}),
                "compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
@@ -348,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
                  pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"):
                  block_edit=None, compile="disabled", attention_mode="sdpa"):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
@@ -396,10 +385,7 @@ class DownloadAndLoadCogVideoGGUFModel:
                transformer_config["in_channels"] = 32
            else:
                transformer_config["in_channels"] = 33
            if pab_config is not None:
                transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
            else:
                transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
            transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
        elif "I2V" in model or "Interpolation" in model:
            transformer_config["in_channels"] = 32
            if "1_5" in model:
@@ -409,16 +395,10 @@ class DownloadAndLoadCogVideoGGUFModel:
                transformer_config["patch_bias"] = False
                transformer_config["sample_height"] = 96
                transformer_config["sample_width"] = 170
            if pab_config is not None:
                transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
            else:
                transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
        else:
            transformer_config["in_channels"] = 16
            if pab_config is not None:
                transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
            else:
                transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)

        params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
        if "2b" in model:
@@ -476,13 +456,13 @@ class DownloadAndLoadCogVideoGGUFModel:
            vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
            vae.load_state_dict(vae_sd)
            if "Pose" in model:
                pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
                pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
            else:
                pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
                pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
        else:
            vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
            vae.load_state_dict(vae_sd)
            pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
            pipe = CogVideoXPipeline(vae, transformer, scheduler)

        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()

nodes.py
@ -44,8 +44,6 @@ from PIL import Image
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
|
||||
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if not "CogVideo" in folder_paths.folder_names_and_paths:
|
||||
@ -53,61 +51,11 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
|
||||
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||
|
||||
#PAB
|
||||
from .videosys.pab import CogVideoXPABConfig
|
||||
|
||||
class CogVideoPABConfig:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
|
||||
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
|
||||
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
|
||||
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||
|
||||
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PAB_CONFIG",)
|
||||
RETURN_NAMES = ("pab_config", )
|
||||
FUNCTION = "config"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
|
||||
|
||||
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
|
||||
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
|
||||
cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
pab_config = CogVideoXPABConfig(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
|
||||
spatial_range=spatial_range,
|
||||
temporal_broadcast=temporal_broadcast,
|
||||
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
|
||||
temporal_range=temporal_range,
|
||||
cross_broadcast=cross_broadcast,
|
||||
cross_threshold=[cross_threshold_end, cross_threshold_start],
|
||||
cross_range=cross_range
|
||||
)
|
||||
|
||||
return (pab_config, )
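# --- illustrative sketch, not part of this commit ---
# A minimal standalone example (assumption: run outside ComfyUI) of what the
# node above produces. The values simply mirror the node defaults; note that
# each threshold pair is stored as [end, start], which is what the PAB manager
# later checks with `threshold[0] < timestep < threshold[1]`.
_node = CogVideoPABConfig()
(_pab_config,) = _node.config(
    spatial_broadcast=True, spatial_threshold_start=850, spatial_threshold_end=100, spatial_range=2,
    temporal_broadcast=False, temporal_threshold_start=850, temporal_threshold_end=100, temporal_range=4,
    cross_broadcast=False, cross_threshold_start=850, cross_threshold_end=100, cross_range=6,
    steps=50,
)
print(_pab_config.spatial_threshold)  # -> [100, 850] (assuming CogVideoXPABConfig stores thresholds like PABConfig)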
|
||||
|
||||
class CogVideoContextOptions:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
|
||||
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],),
|
||||
"context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
|
||||
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
||||
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
||||
@ -1152,9 +1100,6 @@ class CogVideoXFunSampler:
|
||||
end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
|
||||
|
||||
# Load Sampler
|
||||
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
|
||||
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
|
||||
scheduler="CogVideoXDDIM"
|
||||
scheduler_config = pipeline["scheduler_config"]
|
||||
if scheduler in scheduler_mapping:
|
||||
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
||||
@ -1282,7 +1227,7 @@ class CogVideoXFunControlSampler:
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
|
||||
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
|
||||
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
|
||||
samples=None, denoise_strength=1.0, context_options=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -1306,9 +1251,6 @@ class CogVideoXFunControlSampler:
|
||||
|
||||
# Load Sampler
|
||||
scheduler_config = pipeline["scheduler_config"]
|
||||
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
|
||||
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
|
||||
scheduler="CogVideoXDDIM"
|
||||
if scheduler in scheduler_mapping:
|
||||
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
||||
pipe.scheduler = noise_scheduler
|
||||
@ -1427,7 +1369,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
|
||||
"CogVideoXFunControlSampler": CogVideoXFunControlSampler,
|
||||
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
|
||||
"CogVideoPABConfig": CogVideoPABConfig,
|
||||
"CogVideoTransformerEdit": CogVideoTransformerEdit,
|
||||
"CogVideoControlImageEncode": CogVideoControlImageEncode,
|
||||
"CogVideoContextOptions": CogVideoContextOptions,
|
||||
@ -1450,7 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
|
||||
"CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
|
||||
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
|
||||
"CogVideoPABConfig": "CogVideo PABConfig",
|
||||
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
|
||||
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
|
||||
"CogVideoContextOptions": "CogVideo Context Options",
|
||||
|
||||
@ -20,8 +20,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
|
||||
#from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.models import AutoencoderKLCogVideoX
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
@ -35,10 +35,6 @@ from comfy.utils import ProgressBar
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
from .videosys.core.pipeline import VideoSysPipeline
|
||||
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
||||
from .videosys.core.pab_mgr import set_pab_manager
|
||||
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
@ -115,7 +111,7 @@ def retrieve_timesteps(
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
@ -144,10 +140,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
original_mask = None,
|
||||
pab_config = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -164,9 +159,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.video_processor.config.do_resize = False
|
||||
|
||||
if pab_config is not None:
|
||||
set_pab_manager(pab_config)
|
||||
|
||||
self.input_with_padding = True
|
||||
|
||||
|
||||
@ -289,29 +281,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps.to(device), num_inference_steps - t_start
|
||||
|
||||
def _gaussian_weights(self, t_tile_length, t_batch_size):
|
||||
from numpy import pi, exp, sqrt
|
||||
|
||||
var = 0.01
|
||||
midpoint = (t_tile_length - 1) / 2 # -1 because the index runs from 0 to t_tile_length - 1
|
||||
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
|
||||
weights = torch.tensor(t_probs)
|
||||
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
|
||||
return weights
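# --- illustrative sketch, not part of this commit ---
# Standalone illustration (assumed helper, not in the codebase) of the Gaussian
# weights built above: frames near a tile's centre get the largest weight, so
# when overlapping tiles are accumulated and later divided by the summed
# weights ("contributors"), the nearer tile dominates the blend.
import torch
from numpy import pi, exp, sqrt

def gaussian_tile_weights(t_tile_length: int) -> torch.Tensor:
    var = 0.01
    midpoint = (t_tile_length - 1) / 2
    t_probs = [exp(-(t - midpoint) ** 2 / (t_tile_length ** 2) / (2 * var)) / sqrt(2 * pi * var)
               for t in range(t_tile_length)]
    return torch.tensor(t_probs)

w = gaussian_tile_weights(12)
print(w.argmax().item())  # -> 5, the first of the two centre frames of a 12-frame tile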
|
||||
|
||||
# def fuse_qkv_projections(self) -> None:
|
||||
# r"""Enables fused QKV projections."""
|
||||
# self.fusing_transformer = True
|
||||
# self.transformer.fuse_qkv_projections()
|
||||
|
||||
# def unfuse_qkv_projections(self) -> None:
|
||||
# r"""Disable QKV projection fusion if enabled."""
|
||||
# if not self.fusing_transformer:
|
||||
# logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
# else:
|
||||
# self.transformer.unfuse_qkv_projections()
|
||||
# self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
@ -365,8 +334,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_frames: int = 48,
|
||||
t_tile_length: int = 12,
|
||||
t_tile_overlap: int = 4,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
@ -487,9 +454,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
num_frames += self.additional_frames * self.vae_scale_factor_temporal
|
||||
|
||||
|
||||
#if latents is None and num_frames == t_tile_length:
|
||||
# num_frames += 1
|
||||
|
||||
if self.original_mask is not None:
|
||||
image_latents = latents
|
||||
original_image_latents = image_latents
|
||||
@ -569,23 +533,16 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7. context schedule and temporal tiling
|
||||
if context_schedule is not None and context_schedule == "temporal_tiling":
|
||||
t_tile_length = context_frames
|
||||
t_tile_overlap = context_overlap
|
||||
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
|
||||
use_temporal_tiling = True
|
||||
logger.info("Temporal tiling enabled")
|
||||
elif context_schedule is not None:
|
||||
if context_schedule is not None:
|
||||
if image_cond_latents is not None:
|
||||
raise NotImplementedError("Context schedule not currently supported with image conditioning")
|
||||
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
||||
use_temporal_tiling = False
|
||||
use_context_schedule = True
|
||||
from .cogvideox_fun.context import get_context_scheduler
|
||||
context = get_context_scheduler(context_schedule)
|
||||
#todo ofs embeds?
|
||||
|
||||
else:
|
||||
use_temporal_tiling = False
|
||||
use_context_schedule = False
|
||||
logger.info("Temporal tiling and context schedule disabled")
|
||||
# 7.5. Create rotary embeds if required
|
||||
@ -647,100 +604,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
|
||||
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
|
||||
# =====================================================
|
||||
grid_ts = 0
|
||||
cur_t = 0
|
||||
while cur_t < latents.shape[1]:
|
||||
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
|
||||
grid_ts += 1
|
||||
|
||||
all_t = latents.shape[1]
|
||||
latents_all_list = []
|
||||
# =====================================================
|
||||
|
||||
for t_i in range(grid_ts):
|
||||
if t_i < grid_ts - 1:
|
||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
||||
if t_i == grid_ts - 1:
|
||||
ofs_t = all_t - t_tile_length
|
||||
|
||||
input_start_t = ofs_t
|
||||
input_end_t = ofs_t + t_tile_length
|
||||
|
||||
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
|
||||
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
|
||||
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
|
||||
|
||||
#t_input = t[None].to(device)
|
||||
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input_tile,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=t_input,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
|
||||
latents_all_list.append(latents_tile)
|
||||
|
||||
# ==========================================
|
||||
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
||||
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
||||
# Add each tile contribution to overall latents
|
||||
for t_i in range(grid_ts):
|
||||
if t_i < grid_ts - 1:
|
||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
||||
if t_i == grid_ts - 1:
|
||||
ofs_t = all_t - t_tile_length
|
||||
|
||||
input_start_t = ofs_t
|
||||
input_end_t = ofs_t + t_tile_length
|
||||
|
||||
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
|
||||
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
|
||||
|
||||
latents_all /= contributors
|
||||
|
||||
latents = latents_all
|
||||
#print("latents",latents.shape)
|
||||
# start diff diff
|
||||
if i < len(timesteps) - 1 and self.original_mask is not None:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
|
||||
)
|
||||
mask = mask.to(latents)
|
||||
ts_from = timesteps[0]
|
||||
ts_to = timesteps[-1]
|
||||
threshold = (t - ts_to) / (ts_from - ts_to)
|
||||
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
|
||||
latents = image_latent * mask + latents * (1 - mask)
|
||||
# end diff diff
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
comfy_pbar.update(1)
|
||||
# ==========================================
|
||||
elif use_context_schedule:
|
||||
# region context schedule sampling
|
||||
if use_context_schedule:
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
counter = torch.zeros_like(latent_model_input)
|
||||
@ -858,7 +723,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
comfy_pbar.update(1)
|
||||
|
||||
|
||||
# region sampling
|
||||
else:
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@ -1,621 +0,0 @@
|
||||
# Adapted from CogVideo
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# CogVideo: https://github.com/THUDM/CogVideo
|
||||
# diffusers: https://github.com/huggingface/diffusers
|
||||
# --------------------------------------------------------
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||
from torch import nn
|
||||
|
||||
from .core.pab_mgr import enable_pab, if_broadcast_spatial
|
||||
from .modules.embeddings import apply_rotary_emb
|
||||
|
||||
#from .modules.embeddings import CogVideoXPatchEmbed
|
||||
|
||||
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGEATTN_IS_AVAVILABLE = True
|
||||
except:
|
||||
SAGEATTN_IS_AVAVILABLE = False
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
@torch.compiler.disable()
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
attn_heads = attn.heads
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn_heads
|
||||
|
||||
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
emb_len = image_rotary_emb[0].shape[0]
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
|
||||
if SAGEATTN_IS_AVAVILABLE:
|
||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
@torch.compiler.disable()
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
if SAGEATTN_IS_AVAVILABLE:
|
||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# parallel
|
||||
#self.attn1.parallel_manager = None
|
||||
|
||||
# 2. Feed Forward
|
||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# pab
|
||||
self.attn_count = 0
|
||||
self.last_attn = None
|
||||
self.block_idx = block_idx
|
||||
#@torch.compiler.disable()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep=None,
|
||||
video_flow_feature: Optional[torch.Tensor] = None,
|
||||
fuser=None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
# Tora Motion-guidance Fuser
|
||||
if video_flow_feature is not None:
|
||||
H, W = video_flow_feature.shape[-2:]
|
||||
T = norm_hidden_states.shape[1] // H // W
|
||||
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
|
||||
h = fuser(h, video_flow_feature.to(h), T=T)
|
||||
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
|
||||
del h, fuser
|
||||
# attention
|
||||
if enable_pab():
|
||||
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
|
||||
if enable_pab() and broadcast_attn:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
|
||||
else:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
if enable_pab():
|
||||
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
time_embed_dim: int = 512,
|
||||
text_embed_dim: int = 4096,
|
||||
num_layers: int = 30,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
sample_frames: int = 49,
|
||||
patch_size: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
max_text_seq_length: int = 226,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
timestep_activation_fn: str = "silu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
use_learned_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
post_patch_height = sample_height // patch_size
|
||||
post_patch_width = sample_width // patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
||||
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
text_embed_dim=text_embed_dim,
|
||||
bias=True,
|
||||
sample_width=sample_width,
|
||||
sample_height=sample_height,
|
||||
sample_frames=sample_frames,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
max_text_seq_length=max_text_seq_length,
|
||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||
)
|
||||
self.embedding_dropout = nn.Dropout(dropout)
|
||||
|
||||
# 2. 3D positional embeddings
|
||||
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
||||
inner_dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
spatial_interpolation_scale,
|
||||
temporal_interpolation_scale,
|
||||
)
|
||||
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
||||
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
||||
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
||||
|
||||
# 3. Time embeddings
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
|
||||
# 4. Define spatio-temporal transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CogVideoXBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# 5. Output blocks
|
||||
self.norm_out = AdaLayerNorm(
|
||||
embedding_dim=time_embed_dim,
|
||||
output_dim=2 * inner_dim,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.fuser_list = None
|
||||
|
||||
# parallel
|
||||
#self.parallel_manager = None
|
||||
|
||||
# def enable_parallel(self, dp_size, sp_size, enable_cp):
|
||||
# # update cfg parallel
|
||||
# if enable_cp and sp_size % 2 == 0:
|
||||
# sp_size = sp_size // 2
|
||||
# cp_size = 2
|
||||
# else:
|
||||
# cp_size = 1
|
||||
|
||||
# self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
|
||||
|
||||
# for _, module in self.named_modules():
|
||||
# if hasattr(module, "parallel_manager"):
|
||||
# module.parallel_manager = self.parallel_manager
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
controlnet_states: torch.Tensor = None,
|
||||
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
|
||||
video_flow_features: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# if self.parallel_manager.cp_size > 1:
|
||||
# (
|
||||
# hidden_states,
|
||||
# encoder_hidden_states,
|
||||
# timestep,
|
||||
# timestep_cond,
|
||||
# image_rotary_emb,
|
||||
# ) = batch_func(
|
||||
# partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
|
||||
# hidden_states,
|
||||
# encoder_hidden_states,
|
||||
# timestep,
|
||||
# timestep_cond,
|
||||
# image_rotary_emb,
|
||||
# )
|
||||
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. Time embedding
|
||||
timesteps = timestep
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# if self.parallel_manager.sp_size > 1:
|
||||
# set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
|
||||
# hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||
|
||||
# 4. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep=timesteps if enable_pab() else None,
|
||||
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
|
||||
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
|
||||
)
|
||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||
controlnet_states_block = controlnet_states[i]
|
||||
controlnet_block_weight = 1.0
|
||||
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
|
||||
controlnet_block_weight = controlnet_weights[i]
|
||||
elif isinstance(controlnet_weights, (float, int)):
|
||||
controlnet_block_weight = controlnet_weights
|
||||
|
||||
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
||||
|
||||
#if self.parallel_manager.sp_size > 1:
|
||||
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 5. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
#if self.parallel_manager.cp_size > 1:
|
||||
# output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
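# --- illustrative sketch, not part of this commit ---
# Assumed shape check of the unpatchify step above: patchified tokens of shape
# (B, T * H/p * W/p, C * p * p) are reshaped and permuted back to (B, T, C, H, W).
import torch
p, B, T, C, H, W = 2, 1, 13, 16, 60, 90
hidden = torch.randn(B, T * (H // p) * (W // p), C * p * p)
out = hidden.reshape(B, T, H // p, W // p, -1, p, p)
out = out.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(out.shape)  # -> torch.Size([1, 13, 16, 60, 90])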
|
||||
@ -1,232 +0,0 @@
|
||||
|
||||
PAB_MANAGER = None
|
||||
|
||||
|
||||
class PABConfig:
|
||||
def __init__(
|
||||
self,
|
||||
steps: int,
|
||||
cross_broadcast: bool = False,
|
||||
cross_threshold: list = None,
|
||||
cross_range: int = None,
|
||||
spatial_broadcast: bool = False,
|
||||
spatial_threshold: list = None,
|
||||
spatial_range: int = None,
|
||||
temporal_broadcast: bool = False,
|
||||
temporal_threshold: list = None,
|
||||
temporal_range: int = None,
|
||||
mlp_broadcast: bool = False,
|
||||
mlp_spatial_broadcast_config: dict = None,
|
||||
mlp_temporal_broadcast_config: dict = None,
|
||||
):
|
||||
self.steps = steps
|
||||
|
||||
self.cross_broadcast = cross_broadcast
|
||||
self.cross_threshold = cross_threshold
|
||||
self.cross_range = cross_range
|
||||
|
||||
self.spatial_broadcast = spatial_broadcast
|
||||
self.spatial_threshold = spatial_threshold
|
||||
self.spatial_range = spatial_range
|
||||
|
||||
self.temporal_broadcast = temporal_broadcast
|
||||
self.temporal_threshold = temporal_threshold
|
||||
self.temporal_range = temporal_range
|
||||
|
||||
self.mlp_broadcast = mlp_broadcast
|
||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
||||
self.mlp_temporal_outputs = {}
|
||||
self.mlp_spatial_outputs = {}
|
||||
|
||||
|
||||
class PABManager:
|
||||
def __init__(self, config: PABConfig):
|
||||
self.config: PABConfig = config
|
||||
|
||||
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
|
||||
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
|
||||
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
|
||||
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
|
||||
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
|
||||
print(init_prompt)
|
||||
|
||||
def if_broadcast_cross(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.cross_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.cross_range != 0)
|
||||
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_temporal(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.temporal_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.temporal_range != 0)
|
||||
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
|
||||
if (
|
||||
self.config.spatial_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.spatial_range != 0)
|
||||
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
@staticmethod
|
||||
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
||||
is_t_in_skip_config = False
|
||||
skip_range = None
|
||||
for key in config:
|
||||
if key not in all_timesteps:
|
||||
continue
|
||||
index = all_timesteps.index(key)
|
||||
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
|
||||
if timestep in skip_range:
|
||||
is_t_in_skip_config = True
|
||||
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
|
||||
break
|
||||
return is_t_in_skip_config, skip_range
|
||||
|
||||
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not self.config.mlp_broadcast:
|
||||
return False, None, False, None
|
||||
|
||||
if is_temporal:
|
||||
cur_config = self.config.mlp_temporal_broadcast_config
|
||||
else:
|
||||
cur_config = self.config.mlp_spatial_broadcast_config
|
||||
|
||||
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
||||
next_flag = False
|
||||
if (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (timestep in cur_config)
|
||||
and (block_idx in cur_config[timestep]["block"])
|
||||
):
|
||||
flag = False
|
||||
next_flag = True
|
||||
count = count + 1
|
||||
elif (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (is_t_in_skip_config)
|
||||
and (block_idx in cur_config[skip_range[0]]["block"])
|
||||
):
|
||||
flag = True
|
||||
count = 0
|
||||
else:
|
||||
flag = False
|
||||
|
||||
return flag, count, next_flag, skip_range
|
||||
|
||||
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
||||
if is_temporal:
|
||||
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
|
||||
else:
|
||||
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
|
||||
|
||||
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
||||
skip_start_t = skip_range[0]
|
||||
if is_temporal:
|
||||
skip_output = (
|
||||
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_temporal_outputs is not None
|
||||
else None
|
||||
)
|
||||
else:
|
||||
skip_output = (
|
||||
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_spatial_outputs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if skip_output is not None:
|
||||
if timestep == skip_range[-1]:
|
||||
# TODO: save memory
|
||||
if is_temporal:
|
||||
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
||||
)
|
||||
|
||||
return skip_output
|
||||
|
||||
def get_spatial_mlp_outputs(self):
|
||||
return self.config.mlp_spatial_outputs
|
||||
|
||||
def get_temporal_mlp_outputs(self):
|
||||
return self.config.mlp_temporal_outputs
|
||||
|
||||
|
||||
def set_pab_manager(config: PABConfig):
|
||||
global PAB_MANAGER
|
||||
PAB_MANAGER = PABManager(config)
|
||||
|
||||
|
||||
def enable_pab():
|
||||
if PAB_MANAGER is None:
|
||||
return False
|
||||
return (
|
||||
PAB_MANAGER.config.cross_broadcast
|
||||
or PAB_MANAGER.config.spatial_broadcast
|
||||
or PAB_MANAGER.config.temporal_broadcast
|
||||
)
|
||||
|
||||
|
||||
def update_steps(steps: int):
|
||||
if PAB_MANAGER is not None:
|
||||
PAB_MANAGER.config.steps = steps
|
||||
|
||||
|
||||
def if_broadcast_cross(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_cross(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_temporal(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
||||
|
||||
|
||||
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
|
||||
|
||||
|
||||
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
|
||||
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
|
||||
|
||||
|
||||
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
||||
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
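# --- illustrative sketch, not part of this commit ---
# Assumed example of how the helpers above decide when attention is broadcast:
# with spatial_range=2 and a threshold window of (100, 850), cached attention is
# reused on every other call whose timestep falls inside the window, and always
# recomputed outside it.
set_pab_manager(PABConfig(steps=50, spatial_broadcast=True, spatial_threshold=[100, 850], spatial_range=2))
_count = 0
for _t in (999, 800, 780, 760, 50):
    _flag, _count = if_broadcast_spatial(_t, _count, block_idx=0)
    print(_t, _flag)  # -> 999 False, 800 True, 780 False, 760 True, 50 False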
|
||||
@ -1,44 +0,0 @@
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
class VideoSysPipeline(DiffusionPipeline):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def set_eval_and_device(device: torch.device, *modules):
|
||||
for module in modules:
|
||||
module.eval()
|
||||
module.to(device)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
In diffusers, it is a convention to call the pipeline object.
|
||||
But in VideoSys, we use the generate method instead for a clearer interface.
This is a wrapper around the generate method to support the diffusers-style call.
|
||||
"""
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
# modify: remove the config module from the expected modules
|
||||
expected_modules = expected_modules - {"config"}
|
||||
|
||||
optional_names = list(optional_parameters)
|
||||
for name in optional_names:
|
||||
if name in cls._optional_components:
|
||||
expected_modules.add(name)
|
||||
optional_parameters.remove(name)
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
@ -1,3 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
@ -1,71 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CogVideoXDownsample3D(nn.Module):
|
||||
# TODO: wait for paper release.
|
||||
r"""
|
||||
A 3D downsampling layer used in [CogVideoX](https://github.com/THUDM/CogVideo) by Tsinghua University & ZhipuAI.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of channels in the input image.
|
||||
out_channels (`int`):
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size (`int`, defaults to `3`):
|
||||
Size of the convolving kernel.
|
||||
stride (`int`, defaults to `2`):
|
||||
Stride of the convolution.
|
||||
padding (`int`, defaults to `0`):
|
||||
Padding added to all four sides of the input.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to compress the time dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
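# Hedged usage sketch (not part of the original diff): shows the expected shape
# change for a small random input; the channel count of 8 is arbitrary.
import torch

down = CogVideoXDownsample3D(in_channels=8, out_channels=8, compress_time=True)
x = torch.randn(1, 8, 9, 32, 32)  # (batch, channels, frames, height, width)
y = down(x)
# frames 9 -> 5 (first frame kept, rest average-pooled), 32 -> 16 spatially
print(y.shape)  # torch.Size([1, 8, 5, 16, 16])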
@ -1,308 +0,0 @@
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange


class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)

        batch, num_frames, channels, height, width = image_embeds.shape
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
        return embeds
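# Hedged usage sketch (not part of the original diff): illustrates the sequence
# layout after patchifying; all sizes below are small arbitrary values.
import torch

patch_embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=64, text_embed_dim=32)
text_embeds = torch.randn(1, 10, 32)        # (batch, text_seq_len, text_embed_dim)
image_embeds = torch.randn(1, 4, 16, 8, 8)  # (batch, num_frames, channels, height, width)
out = patch_embed(text_embeds, image_embeds)
# 10 text tokens + 4 frames * (8/2) * (8/2) patches = 10 + 64 = 74 tokens
print(out.shape)  # torch.Size([1, 74, 64])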


class OpenSoraPatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
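# Hedged usage sketch (not part of the original diff): the input is padded up to a
# multiple of the patch size before 3D patchifying; sizes below are arbitrary.
import torch

embed = OpenSoraPatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)
video = torch.randn(1, 3, 9, 30, 30)  # T=9, H=W=30 are not multiples of (2, 4, 4)
tokens = embed(video)
# padded to T=10, H=W=32 -> (10/2) * (32/4) * (32/4) = 5 * 8 * 8 = 320 tokens
print(tokens.shape)  # torch.Size([1, 320, 96])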


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
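# Hedged usage sketch (not part of the original diff): sinusoidal features are built
# from cos/sin of t scaled by log-spaced frequencies, then projected by the MLP.
import torch

embedder = TimestepEmbedder(hidden_size=128, frequency_embedding_size=256)
t = torch.tensor([0.0, 250.0, 999.0])  # one timestep per batch element
emb = embedder(t, dtype=torch.float32)
print(emb.shape)  # torch.Size([3, 128])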


class SizeEmbedder(TimestepEmbedder):
    """
    Embeds scalar sizes (e.g. height and width) into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
        s_emb = self.mlp(s_freq)
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb

    @property
    def dtype(self):
        return next(self.parameters()).dtype
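# Hedged usage sketch (not part of the original diff): embeds per-sample scalar sizes
# and concatenates the per-dimension embeddings; values below are arbitrary.
import torch

size_embedder = SizeEmbedder(hidden_size=64)
hw = torch.tensor([[480.0, 720.0]])  # one (height, width) pair
emb = size_embedder(hw, bs=4)        # repeated across a batch of 4
print(emb.shape)  # torch.Size([4, 128])  -> 2 size scalars * hidden_size 64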


def get_3d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    start, stop = crops_coords
    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
    grid_t = torch.from_numpy(grid_t).float()
    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
    freqs_t = freqs_t.repeat_interleave(2, dim=-1)

    # Spatial frequencies for height and width
    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
    grid_h = torch.from_numpy(grid_h).float()
    grid_w = torch.from_numpy(grid_w).float()
    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
    freqs_w = freqs_w.repeat_interleave(2, dim=-1)

    # Broadcast and concatenate tensors along specified dimension
    def broadcast(tensors, dim=-1):
        num_tensors = len(tensors)
        shape_lens = {len(t.shape) for t in tensors}
        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
        shape_len = list(shape_lens)[0]
        dim = (dim + shape_len) if dim < 0 else dim
        dims = list(zip(*(list(t.shape) for t in tensors)))
        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
        assert all(
            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
        ), "invalid dimensions for broadcastable concatenation"
        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
        expanded_dims.insert(dim, (dim, dims[dim]))
        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
        return torch.cat(tensors, dim=dim)

    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)

    t, h, w, d = freqs.shape
    freqs = freqs.view(t * h * w, d)

    # Generate sine and cosine components
    sin = freqs.sin()
    cos = freqs.cos()

    if use_real:
        return cos, sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
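# Hedged usage sketch (not part of the original diff): the head dimension is split
# 1/4 temporal, 3/8 height, 3/8 width; the numbers below are small illustrative values.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,                  # hidden size per attention head
    crops_coords=((0, 0), (8, 8)),
    grid_size=(4, 4),              # latent height x width after patchify
    temporal_size=3,
    use_real=True,
)
# dim_t = 16, dim_h = dim_w = 24 -> 16 + 24 + 24 = 64 channels per token
print(cos.shape, sin.shape)  # torch.Size([48, 64]) torch.Size([48, 64])  (3 * 4 * 4 tokens)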


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. Expected shape: [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for example in Lumina
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for example in Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
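# Hedged usage sketch (not part of the original diff): rotating a query tensor with the
# cos/sin tables produced above; a pure rotation preserves the norm of each head vector.
import torch

cos, sin = get_3d_rotary_pos_embed(64, ((0, 0), (8, 8)), (4, 4), 3)
q = torch.randn(1, 8, 48, 64)  # [batch, heads, seq, head_dim]
q_rot = apply_rotary_emb(q, (cos, sin))
print(q_rot.shape)  # torch.Size([1, 8, 48, 64])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True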
@ -1,85 +0,0 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn


class CogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
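# Hedged usage sketch (not part of the original diff): the timestep embedding drives
# per-channel shift/scale for both the video and text streams, plus gate values;
# all dimensions below are arbitrary.
import torch

norm_zero = CogVideoXLayerNormZero(conditioning_dim=512, embedding_dim=128)
temb = torch.randn(2, 512)
hidden = torch.randn(2, 226, 128)
text = torch.randn(2, 16, 128)
h, t, gate, enc_gate = norm_zero(hidden, text, temb)
print(h.shape, t.shape, gate.shape)  # torch.Size([2, 226, 128]) torch.Size([2, 16, 128]) torch.Size([2, 1, 128])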


class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*):
            The size of the projected embedding; defaults to `embedding_dim * 2`.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the LayerNorm has learnable per-element affine parameters.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon added to the denominator for numerical stability.
        chunk_dim (`int`, defaults to `0`):
            The dimension along which the projected embedding is split into shift and scale.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x
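# Hedged usage sketch (not part of the original diff): the chunk_dim=1 branch used by
# CogVideoX splits the projected embedding into (shift, scale) along the channel dim;
# the sizes below are arbitrary.
import torch

ada_norm = AdaLayerNorm(embedding_dim=512, output_dim=256, chunk_dim=1)
x = torch.randn(2, 100, 128)  # output_dim // 2 = 128 channels are normalized
temb = torch.randn(2, 512)
out = ada_norm(x, temb=temb)
print(out.shape)  # torch.Size([2, 100, 128])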
@ -1,67 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                # split first frame
                x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                x_first = F.interpolate(x_first, scale_factor=2.0)
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                x_first = x_first[:, :, None, :, :]
                inputs = torch.cat([x_first, x_rest], dim=2)
            elif inputs.shape[2] > 1:
                inputs = F.interpolate(inputs, scale_factor=2.0)
            else:
                inputs = inputs.squeeze(2)
                inputs = F.interpolate(inputs, scale_factor=2.0)
                inputs = inputs[:, :, None, :, :]
        else:
            # only interpolate 2D
            b, c, t, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            inputs = F.interpolate(inputs, scale_factor=2.0)
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        b, c, t, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        inputs = self.conv(inputs)
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        return inputs
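# Hedged usage sketch (not part of the original diff): the spatial-only path
# (compress_time=False) doubles height and width and keeps the frame count.
import torch

up = CogVideoXUpsample3D(in_channels=8, out_channels=8)
x = torch.randn(1, 8, 5, 16, 16)
y = up(x)
print(y.shape)  # torch.Size([1, 8, 5, 32, 32])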
@ -1,64 +0,0 @@
class PABConfig:
    def __init__(
        self,
        steps: int,
        cross_broadcast: bool = False,
        cross_threshold: list = None,
        cross_range: int = None,
        spatial_broadcast: bool = False,
        spatial_threshold: list = None,
        spatial_range: int = None,
        temporal_broadcast: bool = False,
        temporal_threshold: list = None,
        temporal_range: int = None,
        mlp_broadcast: bool = False,
        mlp_spatial_broadcast_config: dict = None,
        mlp_temporal_broadcast_config: dict = None,
    ):
        self.steps = steps

        self.cross_broadcast = cross_broadcast
        self.cross_threshold = cross_threshold
        self.cross_range = cross_range

        self.spatial_broadcast = spatial_broadcast
        self.spatial_threshold = spatial_threshold
        self.spatial_range = spatial_range

        self.temporal_broadcast = temporal_broadcast
        self.temporal_threshold = temporal_threshold
        self.temporal_range = temporal_range

        self.mlp_broadcast = mlp_broadcast
        self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
        self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
        self.mlp_temporal_outputs = {}
        self.mlp_spatial_outputs = {}


class CogVideoXPABConfig(PABConfig):
    def __init__(
        self,
        steps: int = 50,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [100, 850],
        spatial_range: int = 2,
        temporal_broadcast: bool = False,
        temporal_threshold: list = [100, 850],
        temporal_range: int = 4,
        cross_broadcast: bool = False,
        cross_threshold: list = [100, 850],
        cross_range: int = 6,
    ):
        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
            temporal_broadcast=temporal_broadcast,
            temporal_threshold=temporal_threshold,
            temporal_range=temporal_range,
            cross_broadcast=cross_broadcast,
            cross_threshold=cross_threshold,
            cross_range=cross_range,
        )
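# Hedged usage sketch (not part of the original diff): a PAB config instance simply
# groups the broadcast flags, timestep thresholds, and refresh ranges that the
# (now removed) pab_mgr reads; the values below are just the class defaults.
pab_config = CogVideoXPABConfig(
    steps=50,
    spatial_broadcast=True,
    spatial_threshold=[100, 850],  # broadcast only between these timesteps
    spatial_range=2,               # re-use spatial attention for up to 2 steps
)
print(pab_config.spatial_broadcast, pab_config.spatial_range)  # True 2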