mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-21 23:42:14 +08:00
code cleanup
codebase getting too bloated: drop PAB support in favor of FasterCache drop temporal tilling in favor of FreeNoise
This commit is contained in:
parent
e8a289112f
commit
0bd3da569e
@ -1,741 +0,0 @@
|
|||||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import glob
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
||||||
from diffusers.utils import is_torch_version, logging
|
|
||||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
||||||
from diffusers.models.attention import Attention, FeedForward
|
|
||||||
#from diffusers.models.attention_processor import AttentionProcessor
|
|
||||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
|
||||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
|
||||||
#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
||||||
|
|
||||||
from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
||||||
from ..videosys.modules.embeddings import apply_rotary_emb
|
|
||||||
from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sageattention import sageattn
|
|
||||||
SAGEATTN_IS_AVAVILABLE = True
|
|
||||||
logger.info("Using sageattn")
|
|
||||||
except:
|
|
||||||
logger.info("sageattn not found, using sdpa")
|
|
||||||
SAGEATTN_IS_AVAVILABLE = False
|
|
||||||
|
|
||||||
class CogVideoXAttnProcessor2_0:
|
|
||||||
r"""
|
|
||||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
|
||||||
query and key vectors, but does not include spatial normalization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
key = attn.to_k(hidden_states)
|
|
||||||
value = attn.to_v(hidden_states)
|
|
||||||
|
|
||||||
# if attn.parallel_manager.sp_size > 1:
|
|
||||||
# assert (
|
|
||||||
# attn.heads % attn.parallel_manager.sp_size == 0
|
|
||||||
# ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
|
|
||||||
# attn_heads = attn.heads // attn.parallel_manager.sp_size
|
|
||||||
# query, key, value = map(
|
|
||||||
# lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
|
|
||||||
# [query, key, value],
|
|
||||||
# )
|
|
||||||
|
|
||||||
attn_heads = attn.heads
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn_heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if attn.norm_q is not None:
|
|
||||||
query = attn.norm_q(query)
|
|
||||||
if attn.norm_k is not None:
|
|
||||||
key = attn.norm_k(key)
|
|
||||||
|
|
||||||
# Apply RoPE if needed
|
|
||||||
if image_rotary_emb is not None:
|
|
||||||
emb_len = image_rotary_emb[0].shape[0]
|
|
||||||
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
|
||||||
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
|
||||||
)
|
|
||||||
if not attn.is_cross_attention:
|
|
||||||
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
|
||||||
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
|
||||||
)
|
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
|
||||||
else:
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
|
||||||
|
|
||||||
#if attn.parallel_manager.sp_size > 1:
|
|
||||||
# hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
|
||||||
)
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FusedCogVideoXAttnProcessor2_0:
|
|
||||||
r"""
|
|
||||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
|
||||||
query and key vectors, but does not include spatial normalization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
qkv = attn.to_qkv(hidden_states)
|
|
||||||
split_size = qkv.shape[-1] // 3
|
|
||||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn.heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if attn.norm_q is not None:
|
|
||||||
query = attn.norm_q(query)
|
|
||||||
if attn.norm_k is not None:
|
|
||||||
key = attn.norm_k(key)
|
|
||||||
|
|
||||||
# Apply RoPE if needed
|
|
||||||
if image_rotary_emb is not None:
|
|
||||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
if not attn.is_cross_attention:
|
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
|
||||||
else:
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
|
||||||
)
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
class CogVideoXPatchEmbed(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size: int = 2,
|
|
||||||
in_channels: int = 16,
|
|
||||||
embed_dim: int = 1920,
|
|
||||||
text_embed_dim: int = 4096,
|
|
||||||
bias: bool = True,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
|
||||||
)
|
|
||||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
|
||||||
|
|
||||||
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
text_embeds (`torch.Tensor`):
|
|
||||||
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
|
||||||
image_embeds (`torch.Tensor`):
|
|
||||||
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
|
||||||
"""
|
|
||||||
text_embeds = self.text_proj(text_embeds)
|
|
||||||
|
|
||||||
batch, num_frames, channels, height, width = image_embeds.shape
|
|
||||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
|
||||||
image_embeds = self.proj(image_embeds)
|
|
||||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
|
||||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
|
||||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
|
||||||
|
|
||||||
embeds = torch.cat(
|
|
||||||
[text_embeds, image_embeds], dim=1
|
|
||||||
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
|
||||||
class CogVideoXBlock(nn.Module):
|
|
||||||
r"""
|
|
||||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
dim (`int`):
|
|
||||||
The number of channels in the input and output.
|
|
||||||
num_attention_heads (`int`):
|
|
||||||
The number of heads to use for multi-head attention.
|
|
||||||
attention_head_dim (`int`):
|
|
||||||
The number of channels in each head.
|
|
||||||
time_embed_dim (`int`):
|
|
||||||
The number of channels in timestep embedding.
|
|
||||||
dropout (`float`, defaults to `0.0`):
|
|
||||||
The dropout probability to use.
|
|
||||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
|
||||||
Activation function to be used in feed-forward.
|
|
||||||
attention_bias (`bool`, defaults to `False`):
|
|
||||||
Whether or not to use bias in attention projection layers.
|
|
||||||
qk_norm (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use normalization after query and key projections in Attention.
|
|
||||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
|
||||||
Whether to use learnable elementwise affine parameters for normalization.
|
|
||||||
norm_eps (`float`, defaults to `1e-5`):
|
|
||||||
Epsilon value for normalization layers.
|
|
||||||
final_dropout (`bool` defaults to `False`):
|
|
||||||
Whether to apply a final dropout after the last feed-forward layer.
|
|
||||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
|
||||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
|
||||||
ff_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in Feed-forward layer.
|
|
||||||
attention_out_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in Attention output projection layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
attention_head_dim: int,
|
|
||||||
time_embed_dim: int,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
activation_fn: str = "gelu-approximate",
|
|
||||||
attention_bias: bool = False,
|
|
||||||
qk_norm: bool = True,
|
|
||||||
norm_elementwise_affine: bool = True,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
final_dropout: bool = True,
|
|
||||||
ff_inner_dim: Optional[int] = None,
|
|
||||||
ff_bias: bool = True,
|
|
||||||
attention_out_bias: bool = True,
|
|
||||||
block_idx: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# 1. Self Attention
|
|
||||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
|
||||||
|
|
||||||
self.attn1 = Attention(
|
|
||||||
query_dim=dim,
|
|
||||||
dim_head=attention_head_dim,
|
|
||||||
heads=num_attention_heads,
|
|
||||||
qk_norm="layer_norm" if qk_norm else None,
|
|
||||||
eps=1e-6,
|
|
||||||
bias=attention_bias,
|
|
||||||
out_bias=attention_out_bias,
|
|
||||||
processor=CogVideoXAttnProcessor2_0(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# parallel
|
|
||||||
#self.attn1.parallel_manager = None
|
|
||||||
|
|
||||||
# 2. Feed Forward
|
|
||||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
|
||||||
|
|
||||||
self.ff = FeedForward(
|
|
||||||
dim,
|
|
||||||
dropout=dropout,
|
|
||||||
activation_fn=activation_fn,
|
|
||||||
final_dropout=final_dropout,
|
|
||||||
inner_dim=ff_inner_dim,
|
|
||||||
bias=ff_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pab
|
|
||||||
self.attn_count = 0
|
|
||||||
self.last_attn = None
|
|
||||||
self.block_idx = block_idx
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
temb: torch.Tensor,
|
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
timestep=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
# norm & modulate
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
|
||||||
hidden_states, encoder_hidden_states, temb
|
|
||||||
)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
if enable_pab():
|
|
||||||
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
|
|
||||||
if enable_pab() and broadcast_attn:
|
|
||||||
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
|
|
||||||
else:
|
|
||||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
|
||||||
hidden_states=norm_hidden_states,
|
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
)
|
|
||||||
if enable_pab():
|
|
||||||
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
|
||||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
|
||||||
|
|
||||||
# norm & modulate
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
|
||||||
hidden_states, encoder_hidden_states, temb
|
|
||||||
)
|
|
||||||
|
|
||||||
# feed-forward
|
|
||||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
|
||||||
ff_output = self.ff(norm_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
|
||||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
|
||||||
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|
||||||
"""
|
|
||||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
num_attention_heads (`int`, defaults to `30`):
|
|
||||||
The number of heads to use for multi-head attention.
|
|
||||||
attention_head_dim (`int`, defaults to `64`):
|
|
||||||
The number of channels in each head.
|
|
||||||
in_channels (`int`, defaults to `16`):
|
|
||||||
The number of channels in the input.
|
|
||||||
out_channels (`int`, *optional*, defaults to `16`):
|
|
||||||
The number of channels in the output.
|
|
||||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
|
||||||
Whether to flip the sin to cos in the time embedding.
|
|
||||||
time_embed_dim (`int`, defaults to `512`):
|
|
||||||
Output dimension of timestep embeddings.
|
|
||||||
text_embed_dim (`int`, defaults to `4096`):
|
|
||||||
Input dimension of text embeddings from the text encoder.
|
|
||||||
num_layers (`int`, defaults to `30`):
|
|
||||||
The number of layers of Transformer blocks to use.
|
|
||||||
dropout (`float`, defaults to `0.0`):
|
|
||||||
The dropout probability to use.
|
|
||||||
attention_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in the attention projection layers.
|
|
||||||
sample_width (`int`, defaults to `90`):
|
|
||||||
The width of the input latents.
|
|
||||||
sample_height (`int`, defaults to `60`):
|
|
||||||
The height of the input latents.
|
|
||||||
sample_frames (`int`, defaults to `49`):
|
|
||||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
|
||||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
|
||||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
|
||||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
|
||||||
patch_size (`int`, defaults to `2`):
|
|
||||||
The size of the patches to use in the patch embedding layer.
|
|
||||||
temporal_compression_ratio (`int`, defaults to `4`):
|
|
||||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
|
||||||
max_text_seq_length (`int`, defaults to `226`):
|
|
||||||
The maximum sequence length of the input text embeddings.
|
|
||||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
|
||||||
Activation function to use in feed-forward.
|
|
||||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
|
||||||
Activation function to use when generating the timestep embeddings.
|
|
||||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use elementwise affine in normalization layers.
|
|
||||||
norm_eps (`float`, defaults to `1e-5`):
|
|
||||||
The epsilon value to use in normalization layers.
|
|
||||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
|
||||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
|
||||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
|
||||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_supports_gradient_checkpointing = True
|
|
||||||
|
|
||||||
@register_to_config
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_attention_heads: int = 30,
|
|
||||||
attention_head_dim: int = 64,
|
|
||||||
in_channels: int = 16,
|
|
||||||
out_channels: Optional[int] = 16,
|
|
||||||
flip_sin_to_cos: bool = True,
|
|
||||||
freq_shift: int = 0,
|
|
||||||
time_embed_dim: int = 512,
|
|
||||||
text_embed_dim: int = 4096,
|
|
||||||
num_layers: int = 30,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
attention_bias: bool = True,
|
|
||||||
sample_width: int = 90,
|
|
||||||
sample_height: int = 60,
|
|
||||||
sample_frames: int = 49,
|
|
||||||
patch_size: int = 2,
|
|
||||||
temporal_compression_ratio: int = 4,
|
|
||||||
max_text_seq_length: int = 226,
|
|
||||||
activation_fn: str = "gelu-approximate",
|
|
||||||
timestep_activation_fn: str = "silu",
|
|
||||||
norm_elementwise_affine: bool = True,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
spatial_interpolation_scale: float = 1.875,
|
|
||||||
temporal_interpolation_scale: float = 1.0,
|
|
||||||
use_rotary_positional_embeddings: bool = False,
|
|
||||||
add_noise_in_inpaint_model: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
|
||||||
|
|
||||||
post_patch_height = sample_height // patch_size
|
|
||||||
post_patch_width = sample_width // patch_size
|
|
||||||
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
|
||||||
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
|
||||||
self.post_patch_height = post_patch_height
|
|
||||||
self.post_patch_width = post_patch_width
|
|
||||||
self.post_time_compression_frames = post_time_compression_frames
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
# 1. Patch embedding
|
|
||||||
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
|
|
||||||
self.embedding_dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
# 2. 3D positional embeddings
|
|
||||||
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
|
||||||
inner_dim,
|
|
||||||
(post_patch_width, post_patch_height),
|
|
||||||
post_time_compression_frames,
|
|
||||||
spatial_interpolation_scale,
|
|
||||||
temporal_interpolation_scale,
|
|
||||||
)
|
|
||||||
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
|
||||||
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
|
||||||
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
|
||||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
|
||||||
|
|
||||||
# 3. Time embeddings
|
|
||||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
|
||||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
|
||||||
|
|
||||||
# 4. Define spatio-temporal transformers blocks
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
CogVideoXBlock(
|
|
||||||
dim=inner_dim,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
attention_head_dim=attention_head_dim,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
activation_fn=activation_fn,
|
|
||||||
attention_bias=attention_bias,
|
|
||||||
norm_elementwise_affine=norm_elementwise_affine,
|
|
||||||
norm_eps=norm_eps,
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
|
||||||
|
|
||||||
# 5. Output blocks
|
|
||||||
self.norm_out = AdaLayerNorm(
|
|
||||||
embedding_dim=time_embed_dim,
|
|
||||||
output_dim=2 * inner_dim,
|
|
||||||
norm_elementwise_affine=norm_elementwise_affine,
|
|
||||||
norm_eps=norm_eps,
|
|
||||||
chunk_dim=1,
|
|
||||||
)
|
|
||||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
|
||||||
self.gradient_checkpointing = value
|
|
||||||
|
|
||||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
|
||||||
def fuse_qkv_projections(self):
|
|
||||||
"""
|
|
||||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
|
||||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This API is 🧪 experimental.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
"""
|
|
||||||
self.original_attn_processors = None
|
|
||||||
|
|
||||||
for _, attn_processor in self.attn_processors.items():
|
|
||||||
if "Added" in str(attn_processor.__class__.__name__):
|
|
||||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
|
||||||
|
|
||||||
self.original_attn_processors = self.attn_processors
|
|
||||||
|
|
||||||
for module in self.modules():
|
|
||||||
if isinstance(module, Attention):
|
|
||||||
module.fuse_projections(fuse=True)
|
|
||||||
|
|
||||||
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
|
||||||
|
|
||||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
|
||||||
def unfuse_qkv_projections(self):
|
|
||||||
"""Disables the fused QKV projection if enabled.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This API is 🧪 experimental.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self.original_attn_processors is not None:
|
|
||||||
self.set_attn_processor(self.original_attn_processors)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
timestep: Union[int, float, torch.LongTensor],
|
|
||||||
timestep_cond: Optional[torch.Tensor] = None,
|
|
||||||
inpaint_latents: Optional[torch.Tensor] = None,
|
|
||||||
control_latents: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
return_dict: bool = True,
|
|
||||||
):
|
|
||||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
|
||||||
|
|
||||||
# 1. Time embedding
|
|
||||||
timesteps = timestep
|
|
||||||
t_emb = self.time_proj(timesteps)
|
|
||||||
|
|
||||||
# timesteps does not contain any weights and will always return f32 tensors
|
|
||||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
|
||||||
# there might be better ways to encapsulate this.
|
|
||||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
|
||||||
emb = self.time_embedding(t_emb, timestep_cond)
|
|
||||||
|
|
||||||
# 2. Patch embedding
|
|
||||||
if inpaint_latents is not None:
|
|
||||||
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
|
|
||||||
if control_latents is not None:
|
|
||||||
hidden_states = torch.concat([hidden_states, control_latents], 2)
|
|
||||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
|
||||||
|
|
||||||
# 3. Position embedding
|
|
||||||
text_seq_length = encoder_hidden_states.shape[1]
|
|
||||||
if not self.config.use_rotary_positional_embeddings:
|
|
||||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
|
||||||
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
|
||||||
pos_embeds = self.pos_embedding
|
|
||||||
emb_size = hidden_states.size()[-1]
|
|
||||||
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
|
|
||||||
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
|
|
||||||
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
|
|
||||||
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
|
|
||||||
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
|
|
||||||
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
|
|
||||||
hidden_states = hidden_states + pos_embeds
|
|
||||||
hidden_states = self.embedding_dropout(hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
|
||||||
hidden_states = hidden_states[:, text_seq_length:]
|
|
||||||
|
|
||||||
# 4. Transformer blocks
|
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
|
||||||
if self.training and self.gradient_checkpointing:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
|
||||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states,
|
|
||||||
emb,
|
|
||||||
image_rotary_emb,
|
|
||||||
**ckpt_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, encoder_hidden_states = block(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
temb=emb,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
timestep=timestep,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.config.use_rotary_positional_embeddings:
|
|
||||||
# CogVideoX-2B
|
|
||||||
hidden_states = self.norm_final(hidden_states)
|
|
||||||
else:
|
|
||||||
# CogVideoX-5B
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
hidden_states = self.norm_final(hidden_states)
|
|
||||||
hidden_states = hidden_states[:, text_seq_length:]
|
|
||||||
|
|
||||||
# 5. Final block
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
|
||||||
hidden_states = self.proj_out(hidden_states)
|
|
||||||
|
|
||||||
# 6. Unpatchify
|
|
||||||
p = self.config.patch_size
|
|
||||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
|
||||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (output,)
|
|
||||||
return Transformer2DModelOutput(sample=output)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
|
|
||||||
if subfolder is not None:
|
|
||||||
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
|
||||||
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
|
||||||
|
|
||||||
config_file = os.path.join(pretrained_model_path, 'config.json')
|
|
||||||
if not os.path.isfile(config_file):
|
|
||||||
raise RuntimeError(f"{config_file} does not exist")
|
|
||||||
with open(config_file, "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
from diffusers.utils import WEIGHTS_NAME
|
|
||||||
model = cls.from_config(config, **transformer_additional_kwargs)
|
|
||||||
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
|
||||||
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
|
||||||
if os.path.exists(model_file):
|
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
|
||||||
elif os.path.exists(model_file_safetensors):
|
|
||||||
from safetensors.torch import load_file, safe_open
|
|
||||||
state_dict = load_file(model_file_safetensors)
|
|
||||||
else:
|
|
||||||
from safetensors.torch import load_file, safe_open
|
|
||||||
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
|
||||||
state_dict = {}
|
|
||||||
for model_file_safetensors in model_files_safetensors:
|
|
||||||
_state_dict = load_file(model_file_safetensors)
|
|
||||||
for key in _state_dict:
|
|
||||||
state_dict[key] = _state_dict[key]
|
|
||||||
|
|
||||||
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
|
|
||||||
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
|
|
||||||
if len(new_shape) == 5:
|
|
||||||
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
|
|
||||||
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
|
|
||||||
else:
|
|
||||||
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
|
|
||||||
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
|
|
||||||
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
|
|
||||||
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
|
||||||
else:
|
|
||||||
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
|
|
||||||
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
|
||||||
|
|
||||||
tmp_state_dict = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
|
||||||
tmp_state_dict[key] = state_dict[key]
|
|
||||||
else:
|
|
||||||
print(key, "Size don't match, skip")
|
|
||||||
state_dict = tmp_state_dict
|
|
||||||
|
|
||||||
m, u = model.load_state_dict(state_dict, strict=False)
|
|
||||||
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
|
||||||
print(m)
|
|
||||||
|
|
||||||
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
|
|
||||||
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
|
|
||||||
|
|
||||||
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
|
||||||
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
|
||||||
|
|
||||||
return model
|
|
||||||
@ -33,10 +33,6 @@ from diffusers.video_processor import VideoProcessor
|
|||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from ..videosys.core.pipeline import VideoSysPipeline
|
|
||||||
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
|
||||||
from ..videosys.core.pab_mgr import set_pab_manager
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@ -158,7 +154,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
|
|||||||
videos: torch.Tensor
|
videos: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-video generation using CogVideoX.
|
Pipeline for text-to-video generation using CogVideoX.
|
||||||
|
|
||||||
@ -188,7 +184,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
vae: AutoencoderKLCogVideoX,
|
vae: AutoencoderKLCogVideoX,
|
||||||
transformer: CogVideoXTransformer3DModel,
|
transformer: CogVideoXTransformer3DModel,
|
||||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||||
pab_config = None
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -210,9 +205,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if pab_config is not None:
|
|
||||||
set_pab_manager(pab_config)
|
|
||||||
|
|
||||||
def prepare_latents(
|
def prepare_latents(
|
||||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
|
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
|
||||||
latents=None, freenoise=True, context_size=None, context_overlap=None
|
latents=None, freenoise=True, context_size=None, context_overlap=None
|
||||||
@ -348,16 +340,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
if accepts_generator:
|
if accepts_generator:
|
||||||
extra_step_kwargs["generator"] = generator
|
extra_step_kwargs["generator"] = generator
|
||||||
return extra_step_kwargs
|
return extra_step_kwargs
|
||||||
|
|
||||||
def _gaussian_weights(self, t_tile_length, t_batch_size):
|
|
||||||
from numpy import pi, exp, sqrt
|
|
||||||
|
|
||||||
var = 0.01
|
|
||||||
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
|
|
||||||
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
|
|
||||||
weights = torch.tensor(t_probs)
|
|
||||||
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||||
def check_inputs(
|
def check_inputs(
|
||||||
@ -697,24 +679,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
# 8. Denoising loop
|
# 8. Denoising loop
|
||||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
|
||||||
# 8.5. Temporal tiling prep
|
if context_schedule is not None:
|
||||||
if context_schedule is not None and context_schedule == "temporal_tiling":
|
|
||||||
t_tile_length = context_frames
|
|
||||||
t_tile_overlap = context_overlap
|
|
||||||
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
|
|
||||||
use_temporal_tiling = True
|
|
||||||
print("Temporal tiling enabled")
|
|
||||||
elif context_schedule is not None:
|
|
||||||
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = True
|
use_context_schedule = True
|
||||||
from .context import get_context_scheduler
|
from .context import get_context_scheduler
|
||||||
context = get_context_scheduler(context_schedule)
|
context = get_context_scheduler(context_schedule)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = False
|
use_context_schedule = False
|
||||||
print("Temporal tiling and context schedule disabled")
|
print(" context schedule disabled")
|
||||||
# 7. Create rotary embeds if required
|
# 7. Create rotary embeds if required
|
||||||
image_rotary_emb = (
|
image_rotary_emb = (
|
||||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||||
@ -735,88 +708,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
if self.interrupt:
|
if self.interrupt:
|
||||||
continue
|
continue
|
||||||
|
if use_context_schedule:
|
||||||
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
|
|
||||||
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
|
|
||||||
# =====================================================
|
|
||||||
grid_ts = 0
|
|
||||||
cur_t = 0
|
|
||||||
while cur_t < latents.shape[1]:
|
|
||||||
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
|
|
||||||
grid_ts += 1
|
|
||||||
|
|
||||||
all_t = latents.shape[1]
|
|
||||||
latents_all_list = []
|
|
||||||
# =====================================================
|
|
||||||
|
|
||||||
image_rotary_emb = (
|
|
||||||
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
|
|
||||||
if self.transformer.config.use_rotary_positional_embeddings
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
|
|
||||||
control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
|
|
||||||
|
|
||||||
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
|
|
||||||
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
|
|
||||||
|
|
||||||
#t_input = t[None].to(device)
|
|
||||||
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
|
||||||
|
|
||||||
# predict noise model_output
|
|
||||||
noise_pred = self.transformer(
|
|
||||||
hidden_states=latent_model_input_tile,
|
|
||||||
encoder_hidden_states=prompt_embeds,
|
|
||||||
timestep=t_input,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
return_dict=False,
|
|
||||||
control_latents=control_latents_tile,
|
|
||||||
)[0]
|
|
||||||
noise_pred = noise_pred.float()
|
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
|
|
||||||
latents_all_list.append(latents_tile)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
# Add each tile contribution to overall latents
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
|
|
||||||
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
|
|
||||||
|
|
||||||
latents_all /= contributors
|
|
||||||
|
|
||||||
latents = latents_all
|
|
||||||
|
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
||||||
progress_bar.update()
|
|
||||||
pbar.update(1)
|
|
||||||
# ==========================================
|
|
||||||
elif use_context_schedule:
|
|
||||||
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|||||||
@ -33,11 +33,6 @@ from diffusers.video_processor import VideoProcessor
|
|||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from ..videosys.core.pipeline import VideoSysPipeline
|
|
||||||
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
|
||||||
from ..videosys.core.pab_mgr import set_pab_manager
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
@ -206,7 +201,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
|
|||||||
videos: torch.Tensor
|
videos: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-video generation using CogVideoX.
|
Pipeline for text-to-video generation using CogVideoX.
|
||||||
|
|
||||||
@ -236,7 +231,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
vae: AutoencoderKLCogVideoX,
|
vae: AutoencoderKLCogVideoX,
|
||||||
transformer: CogVideoXTransformer3DModel,
|
transformer: CogVideoXTransformer3DModel,
|
||||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||||
pab_config = None
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -258,9 +252,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if pab_config is not None:
|
|
||||||
set_pab_manager(pab_config)
|
|
||||||
|
|
||||||
def prepare_latents(
|
def prepare_latents(
|
||||||
self,
|
self,
|
||||||
batch_size,
|
batch_size,
|
||||||
@ -433,16 +424,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
extra_step_kwargs["generator"] = generator
|
extra_step_kwargs["generator"] = generator
|
||||||
return extra_step_kwargs
|
return extra_step_kwargs
|
||||||
|
|
||||||
def _gaussian_weights(self, t_tile_length, t_batch_size):
|
|
||||||
from numpy import pi, exp, sqrt
|
|
||||||
|
|
||||||
var = 0.01
|
|
||||||
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
|
|
||||||
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
|
|
||||||
weights = torch.tensor(t_probs)
|
|
||||||
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||||
def check_inputs(
|
def check_inputs(
|
||||||
self,
|
self,
|
||||||
@ -866,22 +847,14 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
# 7. Create rotary embeds if required
|
# 7. Create rotary embeds if required
|
||||||
if context_schedule is not None and context_schedule == "temporal_tiling":
|
if context_schedule is not None:
|
||||||
t_tile_length = context_frames
|
|
||||||
t_tile_overlap = context_overlap
|
|
||||||
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
|
|
||||||
use_temporal_tiling = True
|
|
||||||
print("Temporal tiling enabled")
|
|
||||||
elif context_schedule is not None:
|
|
||||||
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = True
|
use_context_schedule = True
|
||||||
from .context import get_context_scheduler
|
from .context import get_context_scheduler
|
||||||
context = get_context_scheduler(context_schedule)
|
context = get_context_scheduler(context_schedule)
|
||||||
else:
|
else:
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = False
|
use_context_schedule = False
|
||||||
print("Temporal tiling and context schedule disabled")
|
print("context schedule disabled")
|
||||||
# 7. Create rotary embeds if required
|
# 7. Create rotary embeds if required
|
||||||
image_rotary_emb = (
|
image_rotary_emb = (
|
||||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||||
@ -915,87 +888,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
if self.interrupt:
|
if self.interrupt:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
|
if use_context_schedule:
|
||||||
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
|
|
||||||
# =====================================================
|
|
||||||
grid_ts = 0
|
|
||||||
cur_t = 0
|
|
||||||
while cur_t < latents.shape[1]:
|
|
||||||
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
|
|
||||||
grid_ts += 1
|
|
||||||
|
|
||||||
all_t = latents.shape[1]
|
|
||||||
latents_all_list = []
|
|
||||||
# =====================================================
|
|
||||||
|
|
||||||
image_rotary_emb = (
|
|
||||||
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
|
|
||||||
if self.transformer.config.use_rotary_positional_embeddings
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
|
|
||||||
inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
|
|
||||||
|
|
||||||
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
|
|
||||||
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
|
|
||||||
|
|
||||||
#t_input = t[None].to(device)
|
|
||||||
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
|
||||||
|
|
||||||
# predict noise model_output
|
|
||||||
noise_pred = self.transformer(
|
|
||||||
hidden_states=latent_model_input_tile,
|
|
||||||
encoder_hidden_states=prompt_embeds,
|
|
||||||
timestep=t_input,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
return_dict=False,
|
|
||||||
inpaint_latents=inpaint_latents_tile,
|
|
||||||
)[0]
|
|
||||||
noise_pred = noise_pred.float()
|
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
|
|
||||||
latents_all_list.append(latents_tile)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
# Add each tile contribution to overall latents
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
|
|
||||||
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
|
|
||||||
|
|
||||||
latents_all /= contributors
|
|
||||||
|
|
||||||
latents = latents_all
|
|
||||||
|
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
||||||
progress_bar.update()
|
|
||||||
pbar.update(1)
|
|
||||||
# ==========================================
|
|
||||||
elif use_context_schedule:
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
@ -1133,18 +1026,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
|
|||||||
else:
|
else:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
# if output_type == "numpy":
|
|
||||||
# video = self.decode_latents(latents)
|
|
||||||
# elif not output_type == "latent":
|
|
||||||
# video = self.decode_latents(latents)
|
|
||||||
# video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
|
||||||
# else:
|
|
||||||
# video = latents
|
|
||||||
|
|
||||||
# Offload all models
|
# Offload all models
|
||||||
self.maybe_free_model_hooks()
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
# if not return_dict:
|
|
||||||
# video = torch.from_numpy(video)
|
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
@ -12,15 +12,12 @@ from .pipeline_cogvideox import CogVideoXPipeline
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
|
from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
|
||||||
from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
|
|
||||||
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun
|
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun
|
||||||
|
|
||||||
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
|
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
|
||||||
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
|
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
|
||||||
|
|
||||||
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
from .utils import remove_specific_blocks, log
|
||||||
|
|
||||||
from .utils import check_diffusers_version, remove_specific_blocks, log
|
|
||||||
from comfy.utils import load_torch_file
|
from comfy.utils import load_torch_file
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
@ -95,7 +92,6 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
|
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
|
||||||
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||||
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||||
"pab_config": ("PAB_CONFIG", {"default": None}),
|
|
||||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||||
"lora": ("COGLORA", {"default": None}),
|
"lora": ("COGLORA", {"default": None}),
|
||||||
"compile_args":("COMPILEARGS", ),
|
"compile_args":("COMPILEARGS", ),
|
||||||
@ -111,7 +107,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
|
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
|
||||||
|
|
||||||
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
|
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
|
||||||
enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
|
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
|
||||||
attention_mode="sdpa", load_device="main_device"):
|
attention_mode="sdpa", load_device="main_device"):
|
||||||
|
|
||||||
if precision == "fp16" and "1.5" in model:
|
if precision == "fp16" and "1.5" in model:
|
||||||
@ -188,15 +184,9 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
if "Fun" in model:
|
if "Fun" in model:
|
||||||
if pab_config is not None:
|
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
|
||||||
transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder=subfolder)
|
|
||||||
else:
|
|
||||||
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
|
|
||||||
else:
|
else:
|
||||||
if pab_config is not None:
|
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
|
||||||
transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder=subfolder)
|
|
||||||
else:
|
|
||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
|
|
||||||
|
|
||||||
transformer = transformer.to(dtype).to(transformer_load_device)
|
transformer = transformer.to(dtype).to(transformer_load_device)
|
||||||
|
|
||||||
@ -213,12 +203,12 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if "Fun" in model:
|
if "Fun" in model:
|
||||||
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||||
if "Pose" in model:
|
if "Pose" in model:
|
||||||
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
|
||||||
else:
|
else:
|
||||||
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
|
||||||
else:
|
else:
|
||||||
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||||
if "cogvideox-2b-img2vid" in model:
|
if "cogvideox-2b-img2vid" in model:
|
||||||
pipe.input_with_padding = False
|
pipe.input_with_padding = False
|
||||||
|
|
||||||
@ -296,7 +286,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
backend="nexfort",
|
backend="nexfort",
|
||||||
options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
|
options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
|
||||||
ignores=["vae"],
|
ignores=["vae"],
|
||||||
fuse_qkv_projections=True if pab_config is None else False,
|
fuse_qkv_projections= False,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = {
|
pipeline = {
|
||||||
@ -334,7 +324,6 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"pab_config": ("PAB_CONFIG", {"default": None}),
|
|
||||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||||
#"lora": ("COGLORA", {"default": None}),
|
#"lora": ("COGLORA", {"default": None}),
|
||||||
"compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
"compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||||
@ -348,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
|
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
|
||||||
pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"):
|
block_edit=None, compile="disabled", attention_mode="sdpa"):
|
||||||
|
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
@ -396,10 +385,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
transformer_config["in_channels"] = 32
|
transformer_config["in_channels"] = 32
|
||||||
else:
|
else:
|
||||||
transformer_config["in_channels"] = 33
|
transformer_config["in_channels"] = 33
|
||||||
if pab_config is not None:
|
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
|
||||||
transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
|
|
||||||
else:
|
|
||||||
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
|
|
||||||
elif "I2V" in model or "Interpolation" in model:
|
elif "I2V" in model or "Interpolation" in model:
|
||||||
transformer_config["in_channels"] = 32
|
transformer_config["in_channels"] = 32
|
||||||
if "1_5" in model:
|
if "1_5" in model:
|
||||||
@ -409,16 +395,10 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
transformer_config["patch_bias"] = False
|
transformer_config["patch_bias"] = False
|
||||||
transformer_config["sample_height"] = 96
|
transformer_config["sample_height"] = 96
|
||||||
transformer_config["sample_width"] = 170
|
transformer_config["sample_width"] = 170
|
||||||
if pab_config is not None:
|
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
||||||
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
|
|
||||||
else:
|
|
||||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
|
||||||
else:
|
else:
|
||||||
transformer_config["in_channels"] = 16
|
transformer_config["in_channels"] = 16
|
||||||
if pab_config is not None:
|
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
||||||
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
|
|
||||||
else:
|
|
||||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
|
||||||
|
|
||||||
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
||||||
if "2b" in model:
|
if "2b" in model:
|
||||||
@ -476,13 +456,13 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
|
||||||
vae.load_state_dict(vae_sd)
|
vae.load_state_dict(vae_sd)
|
||||||
if "Pose" in model:
|
if "Pose" in model:
|
||||||
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
|
||||||
else:
|
else:
|
||||||
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
|
||||||
else:
|
else:
|
||||||
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
|
||||||
vae.load_state_dict(vae_sd)
|
vae.load_state_dict(vae_sd)
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||||
|
|
||||||
if enable_sequential_cpu_offload:
|
if enable_sequential_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|||||||
64
nodes.py
64
nodes.py
@ -44,8 +44,6 @@ from PIL import Image
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
if not "CogVideo" in folder_paths.folder_names_and_paths:
|
if not "CogVideo" in folder_paths.folder_names_and_paths:
|
||||||
@ -53,61 +51,11 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
|
|||||||
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
||||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||||
|
|
||||||
#PAB
|
|
||||||
from .videosys.pab import CogVideoXPABConfig
|
|
||||||
|
|
||||||
class CogVideoPABConfig:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {
|
|
||||||
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
|
|
||||||
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
|
||||||
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
|
||||||
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
|
||||||
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
|
|
||||||
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
|
||||||
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
|
||||||
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
|
||||||
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
|
|
||||||
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
|
||||||
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
|
||||||
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
|
||||||
|
|
||||||
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("PAB_CONFIG",)
|
|
||||||
RETURN_NAMES = ("pab_config", )
|
|
||||||
FUNCTION = "config"
|
|
||||||
CATEGORY = "CogVideoWrapper"
|
|
||||||
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
|
|
||||||
|
|
||||||
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
|
|
||||||
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
|
|
||||||
cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
|
|
||||||
|
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
||||||
pab_config = CogVideoXPABConfig(
|
|
||||||
steps=steps,
|
|
||||||
spatial_broadcast=spatial_broadcast,
|
|
||||||
spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
|
|
||||||
spatial_range=spatial_range,
|
|
||||||
temporal_broadcast=temporal_broadcast,
|
|
||||||
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
|
|
||||||
temporal_range=temporal_range,
|
|
||||||
cross_broadcast=cross_broadcast,
|
|
||||||
cross_threshold=[cross_threshold_end, cross_threshold_start],
|
|
||||||
cross_range=cross_range
|
|
||||||
)
|
|
||||||
|
|
||||||
return (pab_config, )
|
|
||||||
|
|
||||||
class CogVideoContextOptions:
|
class CogVideoContextOptions:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
|
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],),
|
||||||
"context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
|
"context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
|
||||||
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
||||||
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
|
||||||
@ -1152,9 +1100,6 @@ class CogVideoXFunSampler:
|
|||||||
end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
|
end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
|
||||||
|
|
||||||
# Load Sampler
|
# Load Sampler
|
||||||
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
|
|
||||||
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
|
|
||||||
scheduler="CogVideoXDDIM"
|
|
||||||
scheduler_config = pipeline["scheduler_config"]
|
scheduler_config = pipeline["scheduler_config"]
|
||||||
if scheduler in scheduler_mapping:
|
if scheduler in scheduler_mapping:
|
||||||
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
||||||
@ -1282,7 +1227,7 @@ class CogVideoXFunControlSampler:
|
|||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
|
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
|
||||||
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
|
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
|
||||||
samples=None, denoise_strength=1.0, context_options=None):
|
samples=None, denoise_strength=1.0, context_options=None):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
@ -1306,9 +1251,6 @@ class CogVideoXFunControlSampler:
|
|||||||
|
|
||||||
# Load Sampler
|
# Load Sampler
|
||||||
scheduler_config = pipeline["scheduler_config"]
|
scheduler_config = pipeline["scheduler_config"]
|
||||||
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
|
|
||||||
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
|
|
||||||
scheduler="CogVideoXDDIM"
|
|
||||||
if scheduler in scheduler_mapping:
|
if scheduler in scheduler_mapping:
|
||||||
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
|
||||||
pipe.scheduler = noise_scheduler
|
pipe.scheduler = noise_scheduler
|
||||||
@ -1427,7 +1369,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
|
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
|
||||||
"CogVideoXFunControlSampler": CogVideoXFunControlSampler,
|
"CogVideoXFunControlSampler": CogVideoXFunControlSampler,
|
||||||
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
|
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
|
||||||
"CogVideoPABConfig": CogVideoPABConfig,
|
|
||||||
"CogVideoTransformerEdit": CogVideoTransformerEdit,
|
"CogVideoTransformerEdit": CogVideoTransformerEdit,
|
||||||
"CogVideoControlImageEncode": CogVideoControlImageEncode,
|
"CogVideoControlImageEncode": CogVideoControlImageEncode,
|
||||||
"CogVideoContextOptions": CogVideoContextOptions,
|
"CogVideoContextOptions": CogVideoContextOptions,
|
||||||
@ -1450,7 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
|
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
|
||||||
"CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
|
"CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
|
||||||
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
|
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
|
||||||
"CogVideoPABConfig": "CogVideo PABConfig",
|
|
||||||
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
|
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
|
||||||
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
|
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
|
||||||
"CogVideoContextOptions": "CogVideo Context Options",
|
"CogVideoContextOptions": "CogVideo Context Options",
|
||||||
|
|||||||
@ -20,8 +20,8 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
|
from diffusers.models import AutoencoderKLCogVideoX
|
||||||
#from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
from diffusers.utils.torch_utils import randn_tensor
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
@ -35,10 +35,6 @@ from comfy.utils import ProgressBar
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
from .videosys.core.pipeline import VideoSysPipeline
|
|
||||||
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
|
||||||
from .videosys.core.pab_mgr import set_pab_manager
|
|
||||||
|
|
||||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||||
tw = tgt_width
|
tw = tgt_width
|
||||||
th = tgt_height
|
th = tgt_height
|
||||||
@ -115,7 +111,7 @@ def retrieve_timesteps(
|
|||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
return timesteps, num_inference_steps
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-video generation using CogVideoX.
|
Pipeline for text-to-video generation using CogVideoX.
|
||||||
|
|
||||||
@ -144,10 +140,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vae: AutoencoderKLCogVideoX,
|
vae: AutoencoderKLCogVideoX,
|
||||||
transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
|
transformer: CogVideoXTransformer3DModel,
|
||||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||||
original_mask = None,
|
original_mask = None,
|
||||||
pab_config = None
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -164,9 +159,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||||
self.video_processor.config.do_resize = False
|
self.video_processor.config.do_resize = False
|
||||||
|
|
||||||
if pab_config is not None:
|
|
||||||
set_pab_manager(pab_config)
|
|
||||||
|
|
||||||
self.input_with_padding = True
|
self.input_with_padding = True
|
||||||
|
|
||||||
|
|
||||||
@ -289,29 +281,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||||
|
|
||||||
return timesteps.to(device), num_inference_steps - t_start
|
return timesteps.to(device), num_inference_steps - t_start
|
||||||
|
|
||||||
def _gaussian_weights(self, t_tile_length, t_batch_size):
|
|
||||||
from numpy import pi, exp, sqrt
|
|
||||||
|
|
||||||
var = 0.01
|
|
||||||
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
|
|
||||||
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
|
|
||||||
weights = torch.tensor(t_probs)
|
|
||||||
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
# def fuse_qkv_projections(self) -> None:
|
|
||||||
# r"""Enables fused QKV projections."""
|
|
||||||
# self.fusing_transformer = True
|
|
||||||
# self.transformer.fuse_qkv_projections()
|
|
||||||
|
|
||||||
# def unfuse_qkv_projections(self) -> None:
|
|
||||||
# r"""Disable QKV projection fusion if enabled."""
|
|
||||||
# if not self.fusing_transformer:
|
|
||||||
# logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
|
||||||
# else:
|
|
||||||
# self.transformer.unfuse_qkv_projections()
|
|
||||||
# self.fusing_transformer = False
|
|
||||||
|
|
||||||
def _prepare_rotary_positional_embeddings(
|
def _prepare_rotary_positional_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -365,8 +334,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
height: int = 480,
|
height: int = 480,
|
||||||
width: int = 720,
|
width: int = 720,
|
||||||
num_frames: int = 48,
|
num_frames: int = 48,
|
||||||
t_tile_length: int = 12,
|
|
||||||
t_tile_overlap: int = 4,
|
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
timesteps: Optional[List[int]] = None,
|
timesteps: Optional[List[int]] = None,
|
||||||
guidance_scale: float = 6,
|
guidance_scale: float = 6,
|
||||||
@ -487,9 +454,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
num_frames += self.additional_frames * self.vae_scale_factor_temporal
|
num_frames += self.additional_frames * self.vae_scale_factor_temporal
|
||||||
|
|
||||||
|
|
||||||
#if latents is None and num_frames == t_tile_length:
|
|
||||||
# num_frames += 1
|
|
||||||
|
|
||||||
if self.original_mask is not None:
|
if self.original_mask is not None:
|
||||||
image_latents = latents
|
image_latents = latents
|
||||||
original_image_latents = image_latents
|
original_image_latents = image_latents
|
||||||
@ -569,23 +533,16 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
|
||||||
# 7. context schedule and temporal tiling
|
# 7. context schedule and temporal tiling
|
||||||
if context_schedule is not None and context_schedule == "temporal_tiling":
|
if context_schedule is not None:
|
||||||
t_tile_length = context_frames
|
|
||||||
t_tile_overlap = context_overlap
|
|
||||||
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
|
|
||||||
use_temporal_tiling = True
|
|
||||||
logger.info("Temporal tiling enabled")
|
|
||||||
elif context_schedule is not None:
|
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
raise NotImplementedError("Context schedule not currently supported with image conditioning")
|
raise NotImplementedError("Context schedule not currently supported with image conditioning")
|
||||||
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = True
|
use_context_schedule = True
|
||||||
from .cogvideox_fun.context import get_context_scheduler
|
from .cogvideox_fun.context import get_context_scheduler
|
||||||
context = get_context_scheduler(context_schedule)
|
context = get_context_scheduler(context_schedule)
|
||||||
|
#todo ofs embeds?
|
||||||
|
|
||||||
else:
|
else:
|
||||||
use_temporal_tiling = False
|
|
||||||
use_context_schedule = False
|
use_context_schedule = False
|
||||||
logger.info("Temporal tiling and context schedule disabled")
|
logger.info("Temporal tiling and context schedule disabled")
|
||||||
# 7.5. Create rotary embeds if required
|
# 7.5. Create rotary embeds if required
|
||||||
@ -647,100 +604,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
if self.interrupt:
|
if self.interrupt:
|
||||||
continue
|
continue
|
||||||
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
|
# region context schedule sampling
|
||||||
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
|
if use_context_schedule:
|
||||||
# =====================================================
|
|
||||||
grid_ts = 0
|
|
||||||
cur_t = 0
|
|
||||||
while cur_t < latents.shape[1]:
|
|
||||||
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
|
|
||||||
grid_ts += 1
|
|
||||||
|
|
||||||
all_t = latents.shape[1]
|
|
||||||
latents_all_list = []
|
|
||||||
# =====================================================
|
|
||||||
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
|
||||||
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
||||||
|
|
||||||
image_rotary_emb = (
|
|
||||||
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
|
|
||||||
if self.transformer.config.use_rotary_positional_embeddings
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
|
|
||||||
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
|
|
||||||
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
|
|
||||||
|
|
||||||
#t_input = t[None].to(device)
|
|
||||||
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
|
||||||
|
|
||||||
# predict noise model_output
|
|
||||||
noise_pred = self.transformer(
|
|
||||||
hidden_states=latent_model_input_tile,
|
|
||||||
encoder_hidden_states=prompt_embeds,
|
|
||||||
timestep=t_input,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
return_dict=False,
|
|
||||||
)[0]
|
|
||||||
noise_pred = noise_pred.float()
|
|
||||||
|
|
||||||
if self.do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
|
|
||||||
latents_all_list.append(latents_tile)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
|
|
||||||
# Add each tile contribution to overall latents
|
|
||||||
for t_i in range(grid_ts):
|
|
||||||
if t_i < grid_ts - 1:
|
|
||||||
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
|
|
||||||
if t_i == grid_ts - 1:
|
|
||||||
ofs_t = all_t - t_tile_length
|
|
||||||
|
|
||||||
input_start_t = ofs_t
|
|
||||||
input_end_t = ofs_t + t_tile_length
|
|
||||||
|
|
||||||
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
|
|
||||||
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
|
|
||||||
|
|
||||||
latents_all /= contributors
|
|
||||||
|
|
||||||
latents = latents_all
|
|
||||||
#print("latents",latents.shape)
|
|
||||||
# start diff diff
|
|
||||||
if i < len(timesteps) - 1 and self.original_mask is not None:
|
|
||||||
noise_timestep = timesteps[i + 1]
|
|
||||||
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
|
|
||||||
)
|
|
||||||
mask = mask.to(latents)
|
|
||||||
ts_from = timesteps[0]
|
|
||||||
ts_to = timesteps[-1]
|
|
||||||
threshold = (t - ts_to) / (ts_from - ts_to)
|
|
||||||
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
|
|
||||||
latents = image_latent * mask + latents * (1 - mask)
|
|
||||||
# end diff diff
|
|
||||||
|
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
||||||
progress_bar.update()
|
|
||||||
comfy_pbar.update(1)
|
|
||||||
# ==========================================
|
|
||||||
elif use_context_schedule:
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
counter = torch.zeros_like(latent_model_input)
|
counter = torch.zeros_like(latent_model_input)
|
||||||
@ -858,7 +723,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
comfy_pbar.update(1)
|
comfy_pbar.update(1)
|
||||||
|
|
||||||
|
# region sampling
|
||||||
else:
|
else:
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|||||||
@ -1,621 +0,0 @@
|
|||||||
# Adapted from CogVideo
|
|
||||||
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# References:
|
|
||||||
# CogVideo: https://github.com/THUDM/CogVideo
|
|
||||||
# diffusers: https://github.com/huggingface/diffusers
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
from einops import rearrange
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
||||||
from diffusers.models.attention import Attention, FeedForward
|
|
||||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
|
|
||||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
|
||||||
from diffusers.utils import is_torch_version
|
|
||||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from .core.pab_mgr import enable_pab, if_broadcast_spatial
|
|
||||||
from .modules.embeddings import apply_rotary_emb
|
|
||||||
|
|
||||||
#from .modules.embeddings import CogVideoXPatchEmbed
|
|
||||||
|
|
||||||
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
||||||
try:
|
|
||||||
from sageattention import sageattn
|
|
||||||
SAGEATTN_IS_AVAVILABLE = True
|
|
||||||
except:
|
|
||||||
SAGEATTN_IS_AVAVILABLE = False
|
|
||||||
|
|
||||||
class CogVideoXAttnProcessor2_0:
|
|
||||||
r"""
|
|
||||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
|
||||||
query and key vectors, but does not include spatial normalization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
@torch.compiler.disable()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
key = attn.to_k(hidden_states)
|
|
||||||
value = attn.to_v(hidden_states)
|
|
||||||
|
|
||||||
attn_heads = attn.heads
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn_heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if attn.norm_q is not None:
|
|
||||||
query = attn.norm_q(query)
|
|
||||||
if attn.norm_k is not None:
|
|
||||||
key = attn.norm_k(key)
|
|
||||||
|
|
||||||
# Apply RoPE if needed
|
|
||||||
if image_rotary_emb is not None:
|
|
||||||
emb_len = image_rotary_emb[0].shape[0]
|
|
||||||
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
|
||||||
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
|
||||||
)
|
|
||||||
if not attn.is_cross_attention:
|
|
||||||
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
|
||||||
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
|
||||||
)
|
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
|
||||||
else:
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
|
||||||
)
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FusedCogVideoXAttnProcessor2_0:
|
|
||||||
r"""
|
|
||||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
|
||||||
query and key vectors, but does not include spatial normalization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
@torch.compiler.disable()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
qkv = attn.to_qkv(hidden_states)
|
|
||||||
split_size = qkv.shape[-1] // 3
|
|
||||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn.heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if attn.norm_q is not None:
|
|
||||||
query = attn.norm_q(query)
|
|
||||||
if attn.norm_k is not None:
|
|
||||||
key = attn.norm_k(key)
|
|
||||||
|
|
||||||
# Apply RoPE if needed
|
|
||||||
if image_rotary_emb is not None:
|
|
||||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
if not attn.is_cross_attention:
|
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
|
|
||||||
if SAGEATTN_IS_AVAVILABLE:
|
|
||||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
|
||||||
else:
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
|
||||||
)
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
|
||||||
class CogVideoXBlock(nn.Module):
|
|
||||||
r"""
|
|
||||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
dim (`int`):
|
|
||||||
The number of channels in the input and output.
|
|
||||||
num_attention_heads (`int`):
|
|
||||||
The number of heads to use for multi-head attention.
|
|
||||||
attention_head_dim (`int`):
|
|
||||||
The number of channels in each head.
|
|
||||||
time_embed_dim (`int`):
|
|
||||||
The number of channels in timestep embedding.
|
|
||||||
dropout (`float`, defaults to `0.0`):
|
|
||||||
The dropout probability to use.
|
|
||||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
|
||||||
Activation function to be used in feed-forward.
|
|
||||||
attention_bias (`bool`, defaults to `False`):
|
|
||||||
Whether or not to use bias in attention projection layers.
|
|
||||||
qk_norm (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use normalization after query and key projections in Attention.
|
|
||||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
|
||||||
Whether to use learnable elementwise affine parameters for normalization.
|
|
||||||
norm_eps (`float`, defaults to `1e-5`):
|
|
||||||
Epsilon value for normalization layers.
|
|
||||||
final_dropout (`bool` defaults to `False`):
|
|
||||||
Whether to apply a final dropout after the last feed-forward layer.
|
|
||||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
|
||||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
|
||||||
ff_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in Feed-forward layer.
|
|
||||||
attention_out_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in Attention output projection layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
attention_head_dim: int,
|
|
||||||
time_embed_dim: int,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
activation_fn: str = "gelu-approximate",
|
|
||||||
attention_bias: bool = False,
|
|
||||||
qk_norm: bool = True,
|
|
||||||
norm_elementwise_affine: bool = True,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
final_dropout: bool = True,
|
|
||||||
ff_inner_dim: Optional[int] = None,
|
|
||||||
ff_bias: bool = True,
|
|
||||||
attention_out_bias: bool = True,
|
|
||||||
block_idx: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# 1. Self Attention
|
|
||||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
|
||||||
|
|
||||||
self.attn1 = Attention(
|
|
||||||
query_dim=dim,
|
|
||||||
dim_head=attention_head_dim,
|
|
||||||
heads=num_attention_heads,
|
|
||||||
qk_norm="layer_norm" if qk_norm else None,
|
|
||||||
eps=1e-6,
|
|
||||||
bias=attention_bias,
|
|
||||||
out_bias=attention_out_bias,
|
|
||||||
processor=CogVideoXAttnProcessor2_0(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# parallel
|
|
||||||
#self.attn1.parallel_manager = None
|
|
||||||
|
|
||||||
# 2. Feed Forward
|
|
||||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
|
||||||
|
|
||||||
self.ff = FeedForward(
|
|
||||||
dim,
|
|
||||||
dropout=dropout,
|
|
||||||
activation_fn=activation_fn,
|
|
||||||
final_dropout=final_dropout,
|
|
||||||
inner_dim=ff_inner_dim,
|
|
||||||
bias=ff_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pab
|
|
||||||
self.attn_count = 0
|
|
||||||
self.last_attn = None
|
|
||||||
self.block_idx = block_idx
|
|
||||||
#@torch.compiler.disable()
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
temb: torch.Tensor,
|
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
timestep=None,
|
|
||||||
video_flow_feature: Optional[torch.Tensor] = None,
|
|
||||||
fuser=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
|
||||||
|
|
||||||
# norm & modulate
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
|
||||||
hidden_states, encoder_hidden_states, temb
|
|
||||||
)
|
|
||||||
# Tora Motion-guidance Fuser
|
|
||||||
if video_flow_feature is not None:
|
|
||||||
H, W = video_flow_feature.shape[-2:]
|
|
||||||
T = norm_hidden_states.shape[1] // H // W
|
|
||||||
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
|
|
||||||
h = fuser(h, video_flow_feature.to(h), T=T)
|
|
||||||
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
|
|
||||||
del h, fuser
|
|
||||||
# attention
|
|
||||||
if enable_pab():
|
|
||||||
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
|
|
||||||
if enable_pab() and broadcast_attn:
|
|
||||||
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
|
|
||||||
else:
|
|
||||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
|
||||||
hidden_states=norm_hidden_states,
|
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
)
|
|
||||||
if enable_pab():
|
|
||||||
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
|
||||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
|
||||||
|
|
||||||
# norm & modulate
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
|
||||||
hidden_states, encoder_hidden_states, temb
|
|
||||||
)
|
|
||||||
|
|
||||||
# feed-forward
|
|
||||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
|
||||||
ff_output = self.ff(norm_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
|
||||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
|
||||||
|
|
||||||
return hidden_states, encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|
||||||
"""
|
|
||||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
num_attention_heads (`int`, defaults to `30`):
|
|
||||||
The number of heads to use for multi-head attention.
|
|
||||||
attention_head_dim (`int`, defaults to `64`):
|
|
||||||
The number of channels in each head.
|
|
||||||
in_channels (`int`, defaults to `16`):
|
|
||||||
The number of channels in the input.
|
|
||||||
out_channels (`int`, *optional*, defaults to `16`):
|
|
||||||
The number of channels in the output.
|
|
||||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
|
||||||
Whether to flip the sin to cos in the time embedding.
|
|
||||||
time_embed_dim (`int`, defaults to `512`):
|
|
||||||
Output dimension of timestep embeddings.
|
|
||||||
text_embed_dim (`int`, defaults to `4096`):
|
|
||||||
Input dimension of text embeddings from the text encoder.
|
|
||||||
num_layers (`int`, defaults to `30`):
|
|
||||||
The number of layers of Transformer blocks to use.
|
|
||||||
dropout (`float`, defaults to `0.0`):
|
|
||||||
The dropout probability to use.
|
|
||||||
attention_bias (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use bias in the attention projection layers.
|
|
||||||
sample_width (`int`, defaults to `90`):
|
|
||||||
The width of the input latents.
|
|
||||||
sample_height (`int`, defaults to `60`):
|
|
||||||
The height of the input latents.
|
|
||||||
sample_frames (`int`, defaults to `49`):
|
|
||||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
|
||||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
|
||||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
|
||||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
|
||||||
patch_size (`int`, defaults to `2`):
|
|
||||||
The size of the patches to use in the patch embedding layer.
|
|
||||||
temporal_compression_ratio (`int`, defaults to `4`):
|
|
||||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
|
||||||
max_text_seq_length (`int`, defaults to `226`):
|
|
||||||
The maximum sequence length of the input text embeddings.
|
|
||||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
|
||||||
Activation function to use in feed-forward.
|
|
||||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
|
||||||
Activation function to use when generating the timestep embeddings.
|
|
||||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
|
||||||
Whether or not to use elementwise affine in normalization layers.
|
|
||||||
norm_eps (`float`, defaults to `1e-5`):
|
|
||||||
The epsilon value to use in normalization layers.
|
|
||||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
|
||||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
|
||||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
|
||||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_supports_gradient_checkpointing = True
|
|
||||||
|
|
||||||
@register_to_config
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_attention_heads: int = 30,
|
|
||||||
attention_head_dim: int = 64,
|
|
||||||
in_channels: int = 16,
|
|
||||||
out_channels: Optional[int] = 16,
|
|
||||||
flip_sin_to_cos: bool = True,
|
|
||||||
freq_shift: int = 0,
|
|
||||||
time_embed_dim: int = 512,
|
|
||||||
text_embed_dim: int = 4096,
|
|
||||||
num_layers: int = 30,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
attention_bias: bool = True,
|
|
||||||
sample_width: int = 90,
|
|
||||||
sample_height: int = 60,
|
|
||||||
sample_frames: int = 49,
|
|
||||||
patch_size: int = 2,
|
|
||||||
temporal_compression_ratio: int = 4,
|
|
||||||
max_text_seq_length: int = 226,
|
|
||||||
activation_fn: str = "gelu-approximate",
|
|
||||||
timestep_activation_fn: str = "silu",
|
|
||||||
norm_elementwise_affine: bool = True,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
spatial_interpolation_scale: float = 1.875,
|
|
||||||
temporal_interpolation_scale: float = 1.0,
|
|
||||||
use_rotary_positional_embeddings: bool = False,
|
|
||||||
use_learned_positional_embeddings: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
|
||||||
|
|
||||||
post_patch_height = sample_height // patch_size
|
|
||||||
post_patch_width = sample_width // patch_size
|
|
||||||
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
|
||||||
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
|
||||||
|
|
||||||
# 1. Patch embedding
|
|
||||||
self.patch_embed = CogVideoXPatchEmbed(
|
|
||||||
patch_size=patch_size,
|
|
||||||
in_channels=in_channels,
|
|
||||||
embed_dim=inner_dim,
|
|
||||||
text_embed_dim=text_embed_dim,
|
|
||||||
bias=True,
|
|
||||||
sample_width=sample_width,
|
|
||||||
sample_height=sample_height,
|
|
||||||
sample_frames=sample_frames,
|
|
||||||
temporal_compression_ratio=temporal_compression_ratio,
|
|
||||||
max_text_seq_length=max_text_seq_length,
|
|
||||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
|
||||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
|
||||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
|
||||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
|
||||||
)
|
|
||||||
self.embedding_dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
# 2. 3D positional embeddings
|
|
||||||
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
|
||||||
inner_dim,
|
|
||||||
(post_patch_width, post_patch_height),
|
|
||||||
post_time_compression_frames,
|
|
||||||
spatial_interpolation_scale,
|
|
||||||
temporal_interpolation_scale,
|
|
||||||
)
|
|
||||||
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
|
||||||
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
|
||||||
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
|
||||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
|
||||||
|
|
||||||
# 3. Time embeddings
|
|
||||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
|
||||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
|
||||||
|
|
||||||
# 4. Define spatio-temporal transformers blocks
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
CogVideoXBlock(
|
|
||||||
dim=inner_dim,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
attention_head_dim=attention_head_dim,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
activation_fn=activation_fn,
|
|
||||||
attention_bias=attention_bias,
|
|
||||||
norm_elementwise_affine=norm_elementwise_affine,
|
|
||||||
norm_eps=norm_eps,
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
|
||||||
|
|
||||||
# 5. Output blocks
|
|
||||||
self.norm_out = AdaLayerNorm(
|
|
||||||
embedding_dim=time_embed_dim,
|
|
||||||
output_dim=2 * inner_dim,
|
|
||||||
norm_elementwise_affine=norm_elementwise_affine,
|
|
||||||
norm_eps=norm_eps,
|
|
||||||
chunk_dim=1,
|
|
||||||
)
|
|
||||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
self.fuser_list = None
|
|
||||||
|
|
||||||
# parallel
|
|
||||||
#self.parallel_manager = None
|
|
||||||
|
|
||||||
# def enable_parallel(self, dp_size, sp_size, enable_cp):
|
|
||||||
# # update cfg parallel
|
|
||||||
# if enable_cp and sp_size % 2 == 0:
|
|
||||||
# sp_size = sp_size // 2
|
|
||||||
# cp_size = 2
|
|
||||||
# else:
|
|
||||||
# cp_size = 1
|
|
||||||
|
|
||||||
# self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
|
|
||||||
|
|
||||||
# for _, module in self.named_modules():
|
|
||||||
# if hasattr(module, "parallel_manager"):
|
|
||||||
# module.parallel_manager = self.parallel_manager
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
|
||||||
self.gradient_checkpointing = value
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
encoder_hidden_states: torch.Tensor,
|
|
||||||
timestep: Union[int, float, torch.LongTensor],
|
|
||||||
timestep_cond: Optional[torch.Tensor] = None,
|
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
return_dict: bool = True,
|
|
||||||
controlnet_states: torch.Tensor = None,
|
|
||||||
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
|
|
||||||
video_flow_features: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
# if self.parallel_manager.cp_size > 1:
|
|
||||||
# (
|
|
||||||
# hidden_states,
|
|
||||||
# encoder_hidden_states,
|
|
||||||
# timestep,
|
|
||||||
# timestep_cond,
|
|
||||||
# image_rotary_emb,
|
|
||||||
# ) = batch_func(
|
|
||||||
# partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
|
|
||||||
# hidden_states,
|
|
||||||
# encoder_hidden_states,
|
|
||||||
# timestep,
|
|
||||||
# timestep_cond,
|
|
||||||
# image_rotary_emb,
|
|
||||||
# )
|
|
||||||
|
|
||||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
|
||||||
|
|
||||||
# 1. Time embedding
|
|
||||||
timesteps = timestep
|
|
||||||
t_emb = self.time_proj(timesteps)
|
|
||||||
|
|
||||||
# timesteps does not contain any weights and will always return f32 tensors
|
|
||||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
|
||||||
# there might be better ways to encapsulate this.
|
|
||||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
|
||||||
emb = self.time_embedding(t_emb, timestep_cond)
|
|
||||||
|
|
||||||
# 2. Patch embedding
|
|
||||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
|
||||||
|
|
||||||
# 3. Position embedding
|
|
||||||
text_seq_length = encoder_hidden_states.shape[1]
|
|
||||||
if not self.config.use_rotary_positional_embeddings:
|
|
||||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
|
||||||
|
|
||||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
|
||||||
hidden_states = hidden_states + pos_embeds
|
|
||||||
hidden_states = self.embedding_dropout(hidden_states)
|
|
||||||
|
|
||||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
|
||||||
hidden_states = hidden_states[:, text_seq_length:]
|
|
||||||
|
|
||||||
# if self.parallel_manager.sp_size > 1:
|
|
||||||
# set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
|
|
||||||
# hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
|
||||||
|
|
||||||
# 4. Transformer blocks
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
|
||||||
hidden_states, encoder_hidden_states = block(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
temb=emb,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
timestep=timesteps if enable_pab() else None,
|
|
||||||
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
|
|
||||||
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
|
|
||||||
)
|
|
||||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
|
||||||
controlnet_states_block = controlnet_states[i]
|
|
||||||
controlnet_block_weight = 1.0
|
|
||||||
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
|
|
||||||
controlnet_block_weight = controlnet_weights[i]
|
|
||||||
elif isinstance(controlnet_weights, (float, int)):
|
|
||||||
controlnet_block_weight = controlnet_weights
|
|
||||||
|
|
||||||
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
|
||||||
|
|
||||||
#if self.parallel_manager.sp_size > 1:
|
|
||||||
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
|
||||||
|
|
||||||
if not self.config.use_rotary_positional_embeddings:
|
|
||||||
# CogVideoX-2B
|
|
||||||
hidden_states = self.norm_final(hidden_states)
|
|
||||||
else:
|
|
||||||
# CogVideoX-5B
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
||||||
hidden_states = self.norm_final(hidden_states)
|
|
||||||
hidden_states = hidden_states[:, text_seq_length:]
|
|
||||||
|
|
||||||
# 5. Final block
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
|
||||||
hidden_states = self.proj_out(hidden_states)
|
|
||||||
|
|
||||||
# 6. Unpatchify
|
|
||||||
p = self.config.patch_size
|
|
||||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
|
||||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
|
||||||
|
|
||||||
#if self.parallel_manager.cp_size > 1:
|
|
||||||
# output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (output,)
|
|
||||||
return Transformer2DModelOutput(sample=output)
|
|
||||||
@ -1,232 +0,0 @@
|
|||||||
|
|
||||||
PAB_MANAGER = None
|
|
||||||
|
|
||||||
|
|
||||||
class PABConfig:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
steps: int,
|
|
||||||
cross_broadcast: bool = False,
|
|
||||||
cross_threshold: list = None,
|
|
||||||
cross_range: int = None,
|
|
||||||
spatial_broadcast: bool = False,
|
|
||||||
spatial_threshold: list = None,
|
|
||||||
spatial_range: int = None,
|
|
||||||
temporal_broadcast: bool = False,
|
|
||||||
temporal_threshold: list = None,
|
|
||||||
temporal_range: int = None,
|
|
||||||
mlp_broadcast: bool = False,
|
|
||||||
mlp_spatial_broadcast_config: dict = None,
|
|
||||||
mlp_temporal_broadcast_config: dict = None,
|
|
||||||
):
|
|
||||||
self.steps = steps
|
|
||||||
|
|
||||||
self.cross_broadcast = cross_broadcast
|
|
||||||
self.cross_threshold = cross_threshold
|
|
||||||
self.cross_range = cross_range
|
|
||||||
|
|
||||||
self.spatial_broadcast = spatial_broadcast
|
|
||||||
self.spatial_threshold = spatial_threshold
|
|
||||||
self.spatial_range = spatial_range
|
|
||||||
|
|
||||||
self.temporal_broadcast = temporal_broadcast
|
|
||||||
self.temporal_threshold = temporal_threshold
|
|
||||||
self.temporal_range = temporal_range
|
|
||||||
|
|
||||||
self.mlp_broadcast = mlp_broadcast
|
|
||||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
|
||||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
|
||||||
self.mlp_temporal_outputs = {}
|
|
||||||
self.mlp_spatial_outputs = {}
|
|
||||||
|
|
||||||
|
|
||||||
class PABManager:
|
|
||||||
def __init__(self, config: PABConfig):
|
|
||||||
self.config: PABConfig = config
|
|
||||||
|
|
||||||
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
|
|
||||||
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
|
|
||||||
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
|
|
||||||
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
|
|
||||||
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
|
|
||||||
print(init_prompt)
|
|
||||||
|
|
||||||
def if_broadcast_cross(self, timestep: int, count: int):
|
|
||||||
if (
|
|
||||||
self.config.cross_broadcast
|
|
||||||
and (timestep is not None)
|
|
||||||
and (count % self.config.cross_range != 0)
|
|
||||||
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
|
||||||
):
|
|
||||||
flag = True
|
|
||||||
else:
|
|
||||||
flag = False
|
|
||||||
count = (count + 1) % self.config.steps
|
|
||||||
return flag, count
|
|
||||||
|
|
||||||
def if_broadcast_temporal(self, timestep: int, count: int):
|
|
||||||
if (
|
|
||||||
self.config.temporal_broadcast
|
|
||||||
and (timestep is not None)
|
|
||||||
and (count % self.config.temporal_range != 0)
|
|
||||||
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
|
||||||
):
|
|
||||||
flag = True
|
|
||||||
else:
|
|
||||||
flag = False
|
|
||||||
count = (count + 1) % self.config.steps
|
|
||||||
return flag, count
|
|
||||||
|
|
||||||
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
|
|
||||||
if (
|
|
||||||
self.config.spatial_broadcast
|
|
||||||
and (timestep is not None)
|
|
||||||
and (count % self.config.spatial_range != 0)
|
|
||||||
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
|
||||||
):
|
|
||||||
flag = True
|
|
||||||
else:
|
|
||||||
flag = False
|
|
||||||
count = (count + 1) % self.config.steps
|
|
||||||
return flag, count
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
|
||||||
is_t_in_skip_config = False
|
|
||||||
skip_range = None
|
|
||||||
for key in config:
|
|
||||||
if key not in all_timesteps:
|
|
||||||
continue
|
|
||||||
index = all_timesteps.index(key)
|
|
||||||
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
|
|
||||||
if timestep in skip_range:
|
|
||||||
is_t_in_skip_config = True
|
|
||||||
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
|
|
||||||
break
|
|
||||||
return is_t_in_skip_config, skip_range
|
|
||||||
|
|
||||||
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
|
||||||
if not self.config.mlp_broadcast:
|
|
||||||
return False, None, False, None
|
|
||||||
|
|
||||||
if is_temporal:
|
|
||||||
cur_config = self.config.mlp_temporal_broadcast_config
|
|
||||||
else:
|
|
||||||
cur_config = self.config.mlp_spatial_broadcast_config
|
|
||||||
|
|
||||||
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
|
||||||
next_flag = False
|
|
||||||
if (
|
|
||||||
self.config.mlp_broadcast
|
|
||||||
and (timestep is not None)
|
|
||||||
and (timestep in cur_config)
|
|
||||||
and (block_idx in cur_config[timestep]["block"])
|
|
||||||
):
|
|
||||||
flag = False
|
|
||||||
next_flag = True
|
|
||||||
count = count + 1
|
|
||||||
elif (
|
|
||||||
self.config.mlp_broadcast
|
|
||||||
and (timestep is not None)
|
|
||||||
and (is_t_in_skip_config)
|
|
||||||
and (block_idx in cur_config[skip_range[0]]["block"])
|
|
||||||
):
|
|
||||||
flag = True
|
|
||||||
count = 0
|
|
||||||
else:
|
|
||||||
flag = False
|
|
||||||
|
|
||||||
return flag, count, next_flag, skip_range
|
|
||||||
|
|
||||||
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
|
||||||
if is_temporal:
|
|
||||||
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
|
|
||||||
else:
|
|
||||||
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
|
|
||||||
|
|
||||||
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
|
||||||
skip_start_t = skip_range[0]
|
|
||||||
if is_temporal:
|
|
||||||
skip_output = (
|
|
||||||
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
|
|
||||||
if self.config.mlp_temporal_outputs is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
skip_output = (
|
|
||||||
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
|
|
||||||
if self.config.mlp_spatial_outputs is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if skip_output is not None:
|
|
||||||
if timestep == skip_range[-1]:
|
|
||||||
# TODO: save memory
|
|
||||||
if is_temporal:
|
|
||||||
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
|
|
||||||
else:
|
|
||||||
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return skip_output
|
|
||||||
|
|
||||||
def get_spatial_mlp_outputs(self):
|
|
||||||
return self.config.mlp_spatial_outputs
|
|
||||||
|
|
||||||
def get_temporal_mlp_outputs(self):
|
|
||||||
return self.config.mlp_temporal_outputs
|
|
||||||
|
|
||||||
|
|
||||||
def set_pab_manager(config: PABConfig):
|
|
||||||
global PAB_MANAGER
|
|
||||||
PAB_MANAGER = PABManager(config)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_pab():
|
|
||||||
if PAB_MANAGER is None:
|
|
||||||
return False
|
|
||||||
return (
|
|
||||||
PAB_MANAGER.config.cross_broadcast
|
|
||||||
or PAB_MANAGER.config.spatial_broadcast
|
|
||||||
or PAB_MANAGER.config.temporal_broadcast
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_steps(steps: int):
|
|
||||||
if PAB_MANAGER is not None:
|
|
||||||
PAB_MANAGER.config.steps = steps
|
|
||||||
|
|
||||||
|
|
||||||
def if_broadcast_cross(timestep: int, count: int):
|
|
||||||
if not enable_pab():
|
|
||||||
return False, count
|
|
||||||
return PAB_MANAGER.if_broadcast_cross(timestep, count)
|
|
||||||
|
|
||||||
|
|
||||||
def if_broadcast_temporal(timestep: int, count: int):
|
|
||||||
if not enable_pab():
|
|
||||||
return False, count
|
|
||||||
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
|
|
||||||
|
|
||||||
|
|
||||||
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
|
||||||
if not enable_pab():
|
|
||||||
return False, count
|
|
||||||
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
|
||||||
|
|
||||||
|
|
||||||
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
|
||||||
if not enable_pab():
|
|
||||||
return False, count
|
|
||||||
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
|
|
||||||
|
|
||||||
|
|
||||||
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
|
|
||||||
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
|
|
||||||
|
|
||||||
|
|
||||||
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
|
||||||
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
|
||||||
@ -1,44 +0,0 @@
|
|||||||
import inspect
|
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
||||||
|
|
||||||
class VideoSysPipeline(DiffusionPipeline):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def set_eval_and_device(device: torch.device, *modules):
|
|
||||||
for module in modules:
|
|
||||||
module.eval()
|
|
||||||
module.to(device)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
In diffusers, it is a convention to call the pipeline object.
|
|
||||||
But in VideoSys, we will use the generate method for better prompt.
|
|
||||||
This is a wrapper for the generate method to support the diffusers usage.
|
|
||||||
"""
|
|
||||||
return self.generate(*args, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_signature_keys(cls, obj):
|
|
||||||
parameters = inspect.signature(obj.__init__).parameters
|
|
||||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
|
||||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
|
||||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
|
||||||
# modify: remove the config module from the expected modules
|
|
||||||
expected_modules = expected_modules - {"config"}
|
|
||||||
|
|
||||||
optional_names = list(optional_parameters)
|
|
||||||
for name in optional_names:
|
|
||||||
if name in cls._optional_components:
|
|
||||||
expected_modules.add(name)
|
|
||||||
optional_parameters.remove(name)
|
|
||||||
|
|
||||||
return expected_modules, optional_parameters
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
|
||||||
@ -1,71 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXDownsample3D(nn.Module):
|
|
||||||
# Todo: Wait for paper relase.
|
|
||||||
r"""
|
|
||||||
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (`int`):
|
|
||||||
Number of channels in the input image.
|
|
||||||
out_channels (`int`):
|
|
||||||
Number of channels produced by the convolution.
|
|
||||||
kernel_size (`int`, defaults to `3`):
|
|
||||||
Size of the convolving kernel.
|
|
||||||
stride (`int`, defaults to `2`):
|
|
||||||
Stride of the convolution.
|
|
||||||
padding (`int`, defaults to `0`):
|
|
||||||
Padding added to all four sides of the input.
|
|
||||||
compress_time (`bool`, defaults to `False`):
|
|
||||||
Whether or not to compress the time dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
stride: int = 2,
|
|
||||||
padding: int = 0,
|
|
||||||
compress_time: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
||||||
self.compress_time = compress_time
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.compress_time:
|
|
||||||
batch_size, channels, frames, height, width = x.shape
|
|
||||||
|
|
||||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
|
||||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
|
||||||
|
|
||||||
if x.shape[-1] % 2 == 1:
|
|
||||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
|
||||||
if x_rest.shape[-1] > 0:
|
|
||||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
|
||||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
|
||||||
|
|
||||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
|
||||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
|
||||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
|
||||||
else:
|
|
||||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
|
||||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
|
||||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
|
||||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
|
||||||
|
|
||||||
# Pad the tensor
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = F.pad(x, pad, mode="constant", value=0)
|
|
||||||
batch_size, channels, frames, height, width = x.shape
|
|
||||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
|
||||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
|
||||||
x = self.conv(x)
|
|
||||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
|
||||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
|
||||||
return x
|
|
||||||
@ -1,308 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
class CogVideoXPatchEmbed(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size: int = 2,
|
|
||||||
in_channels: int = 16,
|
|
||||||
embed_dim: int = 1920,
|
|
||||||
text_embed_dim: int = 4096,
|
|
||||||
bias: bool = True,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
|
||||||
)
|
|
||||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
|
||||||
|
|
||||||
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
text_embeds (`torch.Tensor`):
|
|
||||||
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
|
||||||
image_embeds (`torch.Tensor`):
|
|
||||||
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
|
||||||
"""
|
|
||||||
text_embeds = self.text_proj(text_embeds)
|
|
||||||
|
|
||||||
batch, num_frames, channels, height, width = image_embeds.shape
|
|
||||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
|
||||||
image_embeds = self.proj(image_embeds)
|
|
||||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
|
||||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
|
||||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
|
||||||
|
|
||||||
embeds = torch.cat(
|
|
||||||
[text_embeds, image_embeds], dim=1
|
|
||||||
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
|
|
||||||
class OpenSoraPatchEmbed3D(nn.Module):
|
|
||||||
"""Video to Patch Embedding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
patch_size (int): Patch token size. Default: (2,4,4).
|
|
||||||
in_chans (int): Number of input video channels. Default: 3.
|
|
||||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size=(2, 4, 4),
|
|
||||||
in_chans=3,
|
|
||||||
embed_dim=96,
|
|
||||||
norm_layer=None,
|
|
||||||
flatten=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.flatten = flatten
|
|
||||||
|
|
||||||
self.in_chans = in_chans
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
|
|
||||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
||||||
if norm_layer is not None:
|
|
||||||
self.norm = norm_layer(embed_dim)
|
|
||||||
else:
|
|
||||||
self.norm = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward function."""
|
|
||||||
# padding
|
|
||||||
_, _, D, H, W = x.size()
|
|
||||||
if W % self.patch_size[2] != 0:
|
|
||||||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
|
||||||
if H % self.patch_size[1] != 0:
|
|
||||||
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
|
||||||
if D % self.patch_size[0] != 0:
|
|
||||||
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
|
||||||
|
|
||||||
x = self.proj(x) # (B C T H W)
|
|
||||||
if self.norm is not None:
|
|
||||||
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
x = self.norm(x)
|
|
||||||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
|
||||||
if self.flatten:
|
|
||||||
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(nn.Module):
|
|
||||||
"""
|
|
||||||
Embeds scalar timesteps into vector representations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
|
||||||
super().__init__()
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def timestep_embedding(t, dim, max_period=10000):
|
|
||||||
"""
|
|
||||||
Create sinusoidal timestep embeddings.
|
|
||||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
|
||||||
These may be fractional.
|
|
||||||
:param dim: the dimension of the output.
|
|
||||||
:param max_period: controls the minimum frequency of the embeddings.
|
|
||||||
:return: an (N, D) Tensor of positional embeddings.
|
|
||||||
"""
|
|
||||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
|
||||||
half = dim // 2
|
|
||||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
|
||||||
freqs = freqs.to(device=t.device)
|
|
||||||
args = t[:, None].float() * freqs[None]
|
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
||||||
if dim % 2:
|
|
||||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
def forward(self, t, dtype):
|
|
||||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
|
||||||
if t_freq.dtype != dtype:
|
|
||||||
t_freq = t_freq.to(dtype)
|
|
||||||
t_emb = self.mlp(t_freq)
|
|
||||||
return t_emb
|
|
||||||
|
|
||||||
|
|
||||||
class SizeEmbedder(TimestepEmbedder):
|
|
||||||
"""
|
|
||||||
Embeds scalar timesteps into vector representations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
|
||||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
|
||||||
self.outdim = hidden_size
|
|
||||||
|
|
||||||
def forward(self, s, bs):
|
|
||||||
if s.ndim == 1:
|
|
||||||
s = s[:, None]
|
|
||||||
assert s.ndim == 2
|
|
||||||
if s.shape[0] != bs:
|
|
||||||
s = s.repeat(bs // s.shape[0], 1)
|
|
||||||
assert s.shape[0] == bs
|
|
||||||
b, dims = s.shape[0], s.shape[1]
|
|
||||||
s = rearrange(s, "b d -> (b d)")
|
|
||||||
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
|
|
||||||
s_emb = self.mlp(s_freq)
|
|
||||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
|
||||||
return s_emb
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return next(self.parameters()).dtype
|
|
||||||
|
|
||||||
def get_3d_rotary_pos_embed(
|
|
||||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
RoPE for video tokens with 3D structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embed_dim: (`int`):
|
|
||||||
The embedding dimension size, corresponding to hidden_size_head.
|
|
||||||
crops_coords (`Tuple[int]`):
|
|
||||||
The top-left and bottom-right coordinates of the crop.
|
|
||||||
grid_size (`Tuple[int]`):
|
|
||||||
The grid size of the spatial positional embedding (height, width).
|
|
||||||
temporal_size (`int`):
|
|
||||||
The size of the temporal dimension.
|
|
||||||
theta (`float`):
|
|
||||||
Scaling factor for frequency computation.
|
|
||||||
use_real (`bool`):
|
|
||||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
|
||||||
"""
|
|
||||||
start, stop = crops_coords
|
|
||||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
|
||||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
|
||||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
|
||||||
|
|
||||||
# Compute dimensions for each axis
|
|
||||||
dim_t = embed_dim // 4
|
|
||||||
dim_h = embed_dim // 8 * 3
|
|
||||||
dim_w = embed_dim // 8 * 3
|
|
||||||
|
|
||||||
# Temporal frequencies
|
|
||||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
|
||||||
grid_t = torch.from_numpy(grid_t).float()
|
|
||||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
|
||||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
|
||||||
|
|
||||||
# Spatial frequencies for height and width
|
|
||||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
|
||||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
|
||||||
grid_h = torch.from_numpy(grid_h).float()
|
|
||||||
grid_w = torch.from_numpy(grid_w).float()
|
|
||||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
|
||||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
|
||||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
|
||||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
|
||||||
|
|
||||||
# Broadcast and concatenate tensors along specified dimension
|
|
||||||
def broadcast(tensors, dim=-1):
|
|
||||||
num_tensors = len(tensors)
|
|
||||||
shape_lens = {len(t.shape) for t in tensors}
|
|
||||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
|
||||||
shape_len = list(shape_lens)[0]
|
|
||||||
dim = (dim + shape_len) if dim < 0 else dim
|
|
||||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
|
||||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
|
||||||
assert all(
|
|
||||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
|
||||||
), "invalid dimensions for broadcastable concatenation"
|
|
||||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
|
||||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
|
||||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
|
||||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
|
||||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
|
||||||
return torch.cat(tensors, dim=dim)
|
|
||||||
|
|
||||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
|
||||||
|
|
||||||
t, h, w, d = freqs.shape
|
|
||||||
freqs = freqs.view(t * h * w, d)
|
|
||||||
|
|
||||||
# Generate sine and cosine components
|
|
||||||
sin = freqs.sin()
|
|
||||||
cos = freqs.cos()
|
|
||||||
|
|
||||||
if use_real:
|
|
||||||
return cos, sin
|
|
||||||
else:
|
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
|
||||||
return freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(
|
|
||||||
x: torch.Tensor,
|
|
||||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
|
||||||
use_real: bool = True,
|
|
||||||
use_real_unbind_dim: int = -1,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
|
||||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
|
||||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
|
||||||
tensors contain rotary embeddings and are returned as real tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (`torch.Tensor`):
|
|
||||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
|
||||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
|
||||||
"""
|
|
||||||
if use_real:
|
|
||||||
cos, sin = freqs_cis # [S, D]
|
|
||||||
cos = cos[None, None]
|
|
||||||
sin = sin[None, None]
|
|
||||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
|
||||||
|
|
||||||
if use_real_unbind_dim == -1:
|
|
||||||
# Use for example in Lumina
|
|
||||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
|
||||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
|
||||||
elif use_real_unbind_dim == -2:
|
|
||||||
# Use for example in Stable Audio
|
|
||||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
|
||||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
|
||||||
|
|
||||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
|
||||||
|
|
||||||
return out
|
|
||||||
else:
|
|
||||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
|
||||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
|
||||||
|
|
||||||
return x_out.type_as(x)
|
|
||||||
@ -1,85 +0,0 @@
|
|||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXLayerNormZero(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
conditioning_dim: int,
|
|
||||||
embedding_dim: int,
|
|
||||||
elementwise_affine: bool = True,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
bias: bool = True,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
|
||||||
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
|
||||||
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
|
||||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
|
||||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
|
||||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNorm(nn.Module):
|
|
||||||
r"""
|
|
||||||
Norm layer modified to incorporate timestep embeddings.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
embedding_dim (`int`): The size of each embedding vector.
|
|
||||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
|
||||||
output_dim (`int`, *optional*):
|
|
||||||
norm_elementwise_affine (`bool`, defaults to `False):
|
|
||||||
norm_eps (`bool`, defaults to `False`):
|
|
||||||
chunk_dim (`int`, defaults to `0`):
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embedding_dim: int,
|
|
||||||
num_embeddings: Optional[int] = None,
|
|
||||||
output_dim: Optional[int] = None,
|
|
||||||
norm_elementwise_affine: bool = False,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
chunk_dim: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.chunk_dim = chunk_dim
|
|
||||||
output_dim = output_dim or embedding_dim * 2
|
|
||||||
|
|
||||||
if num_embeddings is not None:
|
|
||||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
|
||||||
else:
|
|
||||||
self.emb = None
|
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
|
||||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
|
||||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if self.emb is not None:
|
|
||||||
temb = self.emb(timestep)
|
|
||||||
|
|
||||||
temb = self.linear(self.silu(temb))
|
|
||||||
|
|
||||||
if self.chunk_dim == 1:
|
|
||||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
|
||||||
# other if-branch. This branch is specific to CogVideoX for now.
|
|
||||||
shift, scale = temb.chunk(2, dim=1)
|
|
||||||
shift = shift[:, None, :]
|
|
||||||
scale = scale[:, None, :]
|
|
||||||
else:
|
|
||||||
scale, shift = temb.chunk(2, dim=0)
|
|
||||||
|
|
||||||
x = self.norm(x) * (1 + scale) + shift
|
|
||||||
return x
|
|
||||||
@ -1,67 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXUpsample3D(nn.Module):
|
|
||||||
r"""
|
|
||||||
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (`int`):
|
|
||||||
Number of channels in the input image.
|
|
||||||
out_channels (`int`):
|
|
||||||
Number of channels produced by the convolution.
|
|
||||||
kernel_size (`int`, defaults to `3`):
|
|
||||||
Size of the convolving kernel.
|
|
||||||
stride (`int`, defaults to `1`):
|
|
||||||
Stride of the convolution.
|
|
||||||
padding (`int`, defaults to `1`):
|
|
||||||
Padding added to all four sides of the input.
|
|
||||||
compress_time (`bool`, defaults to `False`):
|
|
||||||
Whether or not to compress the time dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
stride: int = 1,
|
|
||||||
padding: int = 1,
|
|
||||||
compress_time: bool = False,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
||||||
self.compress_time = compress_time
|
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.compress_time:
|
|
||||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
|
||||||
# split first frame
|
|
||||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
|
||||||
|
|
||||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
|
||||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
|
||||||
x_first = x_first[:, :, None, :, :]
|
|
||||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
|
||||||
elif inputs.shape[2] > 1:
|
|
||||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
|
||||||
else:
|
|
||||||
inputs = inputs.squeeze(2)
|
|
||||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
|
||||||
inputs = inputs[:, :, None, :, :]
|
|
||||||
else:
|
|
||||||
# only interpolate 2D
|
|
||||||
b, c, t, h, w = inputs.shape
|
|
||||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
|
||||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
|
||||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
b, c, t, h, w = inputs.shape
|
|
||||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
|
||||||
inputs = self.conv(inputs)
|
|
||||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
return inputs
|
|
||||||
@ -1,64 +0,0 @@
|
|||||||
class PABConfig:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
steps: int,
|
|
||||||
cross_broadcast: bool = False,
|
|
||||||
cross_threshold: list = None,
|
|
||||||
cross_range: int = None,
|
|
||||||
spatial_broadcast: bool = False,
|
|
||||||
spatial_threshold: list = None,
|
|
||||||
spatial_range: int = None,
|
|
||||||
temporal_broadcast: bool = False,
|
|
||||||
temporal_threshold: list = None,
|
|
||||||
temporal_range: int = None,
|
|
||||||
mlp_broadcast: bool = False,
|
|
||||||
mlp_spatial_broadcast_config: dict = None,
|
|
||||||
mlp_temporal_broadcast_config: dict = None,
|
|
||||||
):
|
|
||||||
self.steps = steps
|
|
||||||
|
|
||||||
self.cross_broadcast = cross_broadcast
|
|
||||||
self.cross_threshold = cross_threshold
|
|
||||||
self.cross_range = cross_range
|
|
||||||
|
|
||||||
self.spatial_broadcast = spatial_broadcast
|
|
||||||
self.spatial_threshold = spatial_threshold
|
|
||||||
self.spatial_range = spatial_range
|
|
||||||
|
|
||||||
self.temporal_broadcast = temporal_broadcast
|
|
||||||
self.temporal_threshold = temporal_threshold
|
|
||||||
self.temporal_range = temporal_range
|
|
||||||
|
|
||||||
self.mlp_broadcast = mlp_broadcast
|
|
||||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
|
||||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
|
||||||
self.mlp_temporal_outputs = {}
|
|
||||||
self.mlp_spatial_outputs = {}
|
|
||||||
|
|
||||||
class CogVideoXPABConfig(PABConfig):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
steps: int = 50,
|
|
||||||
spatial_broadcast: bool = True,
|
|
||||||
spatial_threshold: list = [100, 850],
|
|
||||||
spatial_range: int = 2,
|
|
||||||
temporal_broadcast: bool = False,
|
|
||||||
temporal_threshold: list = [100, 850],
|
|
||||||
temporal_range: int = 4,
|
|
||||||
cross_broadcast: bool = False,
|
|
||||||
cross_threshold: list = [100, 850],
|
|
||||||
cross_range: int = 6,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
steps=steps,
|
|
||||||
spatial_broadcast=spatial_broadcast,
|
|
||||||
spatial_threshold=spatial_threshold,
|
|
||||||
spatial_range=spatial_range,
|
|
||||||
temporal_broadcast=temporal_broadcast,
|
|
||||||
temporal_threshold=temporal_threshold,
|
|
||||||
temporal_range=temporal_range,
|
|
||||||
cross_broadcast=cross_broadcast,
|
|
||||||
cross_threshold=cross_threshold,
|
|
||||||
cross_range=cross_range
|
|
||||||
|
|
||||||
)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user