code cleanup

codebase getting too bloated:
drop PAB support in favor of FasterCache
drop temporal tilling in favor of FreeNoise
This commit is contained in:
kijai 2024-11-14 19:54:52 +02:00
parent e8a289112f
commit 0bd3da569e
17 changed files with 35 additions and 2711 deletions

View File

@ -1,741 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import os
import json
import torch
import glob
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
#from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from ..videosys.modules.embeddings import apply_rotary_emb
from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True
logger.info("Using sageattn")
except:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAVILABLE = False
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# if attn.parallel_manager.sp_size > 1:
# assert (
# attn.heads % attn.parallel_manager.sp_size == 0
# ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
# attn_heads = attn.heads // attn.parallel_manager.sp_size
# query, key, value = map(
# lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
# [query, key, value],
# )
attn_heads = attn.heads
inner_dim = key.shape[-1]
head_dim = inner_dim // attn_heads
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
#if attn.parallel_manager.sp_size > 1:
# hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
block_idx: int = 0,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# parallel
#self.attn1.parallel_manager = None
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# pab
self.attn_count = 0
self.last_attn = None
self.block_idx = block_idx
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep=None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if enable_pab():
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
add_noise_in_inpaint_model: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
self.post_patch_height = post_patch_height
self.post_patch_width = post_patch_width
self.post_time_compression_frames = post_time_compression_frames
self.patch_size = patch_size
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 2)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
pos_embeds = self.pos_embedding
emb_size = hidden_states.size()[-1]
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
timestep=timestep,
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
else:
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
return model

View File

@ -33,10 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
from ..videosys.core.pipeline import VideoSysPipeline
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from ..videosys.core.pab_mgr import set_pab_manager
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -158,7 +154,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
videos: torch.Tensor
class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
@ -188,7 +184,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
pab_config = None
):
super().__init__()
@ -210,9 +205,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
if pab_config is not None:
set_pab_manager(pab_config)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
latents=None, freenoise=True, context_size=None, context_overlap=None
@ -348,16 +340,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _gaussian_weights(self, t_tile_length, t_batch_size):
from numpy import pi, exp, sqrt
var = 0.01
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
weights = torch.tensor(t_probs)
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
@ -697,24 +679,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.5. Temporal tiling prep
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
print("Temporal tiling enabled")
elif context_schedule is not None:
if context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
print(" context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@ -735,88 +708,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1]
latents_all_list = []
# =====================================================
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=control_latents_tile,
)[0]
noise_pred = noise_pred.float()
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
pbar.update(1)
# ==========================================
elif use_context_schedule:
if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@ -33,11 +33,6 @@ from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
from ..videosys.core.pipeline import VideoSysPipeline
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from ..videosys.core.pab_mgr import set_pab_manager
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -206,7 +201,7 @@ class CogVideoX_Fun_PipelineOutput(BaseOutput):
videos: torch.Tensor
class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
@ -236,7 +231,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
pab_config = None
):
super().__init__()
@ -258,9 +252,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
if pab_config is not None:
set_pab_manager(pab_config)
def prepare_latents(
self,
batch_size,
@ -433,16 +424,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _gaussian_weights(self, t_tile_length, t_batch_size):
from numpy import pi, exp, sqrt
var = 0.01
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
weights = torch.tensor(t_probs)
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
@ -866,22 +847,14 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
print("Temporal tiling enabled")
elif context_schedule is not None:
if context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
print("context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@ -915,87 +888,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1]
latents_all_list = []
# =====================================================
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=inpaint_latents_tile,
)[0]
noise_pred = noise_pred.float()
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
pbar.update(1)
# ==========================================
elif use_context_schedule:
if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@ -1133,18 +1026,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
else:
pbar.update(1)
# if output_type == "numpy":
# video = self.decode_latents(latents)
# elif not output_type == "latent":
# video = self.decode_latents(latents)
# video = self.video_processor.postprocess_video(video=video, output_type=output_type)
# else:
# video = latents
# Offload all models
self.maybe_free_model_hooks()
# if not return_dict:
# video = torch.from_numpy(video)
return latents

View File

@ -12,15 +12,12 @@ from .pipeline_cogvideox import CogVideoXPipeline
from contextlib import nullcontext
from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from .utils import check_diffusers_version, remove_specific_blocks, log
from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file
script_directory = os.path.dirname(os.path.abspath(__file__))
@ -95,7 +92,6 @@ class DownloadAndLoadCogVideoModel:
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
@ -111,7 +107,7 @@ class DownloadAndLoadCogVideoModel:
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
if precision == "fp16" and "1.5" in model:
@ -188,15 +184,9 @@ class DownloadAndLoadCogVideoModel:
# transformer
if "Fun" in model:
if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder=subfolder)
else:
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
else:
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder=subfolder)
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
@ -213,12 +203,12 @@ class DownloadAndLoadCogVideoModel:
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoXPipeline(vae, transformer, scheduler)
if "cogvideox-2b-img2vid" in model:
pipe.input_with_padding = False
@ -296,7 +286,7 @@ class DownloadAndLoadCogVideoModel:
backend="nexfort",
options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
ignores=["vae"],
fuse_qkv_projections=True if pab_config is None else False,
fuse_qkv_projections= False,
)
pipeline = {
@ -334,7 +324,6 @@ class DownloadAndLoadCogVideoGGUFModel:
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
},
"optional": {
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
#"lora": ("COGLORA", {"default": None}),
"compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
@ -348,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"):
block_edit=None, compile="disabled", attention_mode="sdpa"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -396,10 +385,7 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
elif "I2V" in model or "Interpolation" in model:
transformer_config["in_channels"] = 32
if "1_5" in model:
@ -409,16 +395,10 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 96
transformer_config["sample_width"] = 170
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
else:
transformer_config["in_channels"] = 16
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
@ -476,13 +456,13 @@ class DownloadAndLoadCogVideoGGUFModel:
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoXPipeline(vae, transformer, scheduler)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()

View File

@ -44,8 +44,6 @@ from PIL import Image
import numpy as np
import json
script_directory = os.path.dirname(os.path.abspath(__file__))
if not "CogVideo" in folder_paths.folder_names_and_paths:
@ -53,61 +51,11 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
#PAB
from .videosys.pab import CogVideoXPABConfig
class CogVideoPABConfig:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
}
}
RETURN_TYPES = ("PAB_CONFIG",)
RETURN_NAMES = ("pab_config", )
FUNCTION = "config"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
pab_config = CogVideoXPABConfig(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=[cross_threshold_end, cross_threshold_start],
cross_range=cross_range
)
return (pab_config, )
class CogVideoContextOptions:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],),
"context_frames": ("INT", {"default": 48, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
@ -1152,9 +1100,6 @@ class CogVideoXFunSampler:
end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
# Load Sampler
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
scheduler="CogVideoXDDIM"
scheduler_config = pipeline["scheduler_config"]
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
@ -1282,7 +1227,7 @@ class CogVideoXFunControlSampler:
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
samples=None, denoise_strength=1.0, context_options=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -1306,9 +1251,6 @@ class CogVideoXFunControlSampler:
# Load Sampler
scheduler_config = pipeline["scheduler_config"]
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
log.info("Temporal tiling enabled, changing scheduler to CogVideoXDDIM")
scheduler="CogVideoXDDIM"
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
pipe.scheduler = noise_scheduler
@ -1427,7 +1369,6 @@ NODE_CLASS_MAPPINGS = {
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
"CogVideoXFunControlSampler": CogVideoXFunControlSampler,
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
"CogVideoPABConfig": CogVideoPABConfig,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoContextOptions": CogVideoContextOptions,
@ -1450,7 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
"CogVideoXFunControlSampler": "CogVideoXFun Control Sampler",
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
"CogVideoPABConfig": "CogVideo PABConfig",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoContextOptions": "CogVideo Context Options",

View File

@ -20,8 +20,8 @@ import torch
import torch.nn.functional as F
import math
from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
#from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.models import AutoencoderKLCogVideoX
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
@ -35,10 +35,6 @@ from comfy.utils import ProgressBar
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from .videosys.core.pipeline import VideoSysPipeline
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from .videosys.core.pab_mgr import set_pab_manager
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
@ -115,7 +111,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@ -144,10 +140,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
def __init__(
self,
vae: AutoencoderKLCogVideoX,
transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
original_mask = None,
pab_config = None
):
super().__init__()
@ -164,9 +159,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.video_processor.config.do_resize = False
if pab_config is not None:
set_pab_manager(pab_config)
self.input_with_padding = True
@ -289,29 +281,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps.to(device), num_inference_steps - t_start
def _gaussian_weights(self, t_tile_length, t_batch_size):
from numpy import pi, exp, sqrt
var = 0.01
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
weights = torch.tensor(t_probs)
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
# def fuse_qkv_projections(self) -> None:
# r"""Enables fused QKV projections."""
# self.fusing_transformer = True
# self.transformer.fuse_qkv_projections()
# def unfuse_qkv_projections(self) -> None:
# r"""Disable QKV projection fusion if enabled."""
# if not self.fusing_transformer:
# logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
# else:
# self.transformer.unfuse_qkv_projections()
# self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
@ -365,8 +334,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
height: int = 480,
width: int = 720,
num_frames: int = 48,
t_tile_length: int = 12,
t_tile_overlap: int = 4,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@ -487,9 +454,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
num_frames += self.additional_frames * self.vae_scale_factor_temporal
#if latents is None and num_frames == t_tile_length:
# num_frames += 1
if self.original_mask is not None:
image_latents = latents
original_image_latents = image_latents
@ -569,23 +533,16 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7. context schedule and temporal tiling
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
logger.info("Temporal tiling enabled")
elif context_schedule is not None:
if context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .cogvideox_fun.context import get_context_scheduler
context = get_context_scheduler(context_schedule)
#todo ofs embeds?
else:
use_temporal_tiling = False
use_context_schedule = False
logger.info("Temporal tiling and context schedule disabled")
# 7.5. Create rotary embeds if required
@ -647,100 +604,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1]
latents_all_list = []
# =====================================================
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
#print("latents",latents.shape)
# start diff diff
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
)
mask = mask.to(latents)
ts_from = timesteps[0]
ts_to = timesteps[-1]
threshold = (t - ts_to) / (ts_from - ts_to)
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# ==========================================
elif use_context_schedule:
# region context schedule sampling
if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
counter = torch.zeros_like(latent_model_input)
@ -858,7 +723,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# region sampling
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@ -1,621 +0,0 @@
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from typing import Any, Dict, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from .core.pab_mgr import enable_pab, if_broadcast_spatial
from .modules.embeddings import apply_rotary_emb
#from .modules.embeddings import CogVideoXPatchEmbed
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
try:
from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True
except:
SAGEATTN_IS_AVAVILABLE = False
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
attn_heads = attn.heads
inner_dim = key.shape[-1]
head_dim = inner_dim // attn_heads
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
block_idx: int = 0,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# parallel
#self.attn1.parallel_manager = None
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# pab
self.attn_count = 0
self.last_attn = None
self.block_idx = block_idx
#@torch.compiler.disable()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep=None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if enable_pab():
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
self.fuser_list = None
# parallel
#self.parallel_manager = None
# def enable_parallel(self, dp_size, sp_size, enable_cp):
# # update cfg parallel
# if enable_cp and sp_size % 2 == 0:
# sp_size = sp_size // 2
# cp_size = 2
# else:
# cp_size = 1
# self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
# for _, module in self.named_modules():
# if hasattr(module, "parallel_manager"):
# module.parallel_manager = self.parallel_manager
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
video_flow_features: Optional[torch.Tensor] = None,
):
# if self.parallel_manager.cp_size > 1:
# (
# hidden_states,
# encoder_hidden_states,
# timestep,
# timestep_cond,
# image_rotary_emb,
# ) = batch_func(
# partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
# hidden_states,
# encoder_hidden_states,
# timestep,
# timestep_cond,
# image_rotary_emb,
# )
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# if self.parallel_manager.sp_size > 1:
# set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
# hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
timestep=timesteps if enable_pab() else None,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
#if self.parallel_manager.sp_size > 1:
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
#if self.parallel_manager.cp_size > 1:
# output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@ -1,232 +0,0 @@
PAB_MANAGER = None
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
spatial_broadcast: bool = False,
spatial_threshold: list = None,
spatial_range: int = None,
temporal_broadcast: bool = False,
temporal_threshold: list = None,
temporal_range: int = None,
mlp_broadcast: bool = False,
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_range = temporal_range
self.mlp_broadcast = mlp_broadcast
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
self.mlp_temporal_outputs = {}
self.mlp_spatial_outputs = {}
class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
print(init_prompt)
def if_broadcast_cross(self, timestep: int, count: int):
if (
self.config.cross_broadcast
and (timestep is not None)
and (count % self.config.cross_range != 0)
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_temporal(self, timestep: int, count: int):
if (
self.config.temporal_broadcast
and (timestep is not None)
and (count % self.config.temporal_range != 0)
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
and (count % self.config.spatial_range != 0)
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
@staticmethod
def _is_t_in_skip_config(all_timesteps, timestep, config):
is_t_in_skip_config = False
skip_range = None
for key in config:
if key not in all_timesteps:
continue
index = all_timesteps.index(key)
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
if timestep in skip_range:
is_t_in_skip_config = True
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
break
return is_t_in_skip_config, skip_range
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not self.config.mlp_broadcast:
return False, None, False, None
if is_temporal:
cur_config = self.config.mlp_temporal_broadcast_config
else:
cur_config = self.config.mlp_spatial_broadcast_config
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
next_flag = False
if (
self.config.mlp_broadcast
and (timestep is not None)
and (timestep in cur_config)
and (block_idx in cur_config[timestep]["block"])
):
flag = False
next_flag = True
count = count + 1
elif (
self.config.mlp_broadcast
and (timestep is not None)
and (is_t_in_skip_config)
and (block_idx in cur_config[skip_range[0]]["block"])
):
flag = True
count = 0
else:
flag = False
return flag, count, next_flag, skip_range
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
if is_temporal:
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
else:
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
skip_start_t = skip_range[0]
if is_temporal:
skip_output = (
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_temporal_outputs is not None
else None
)
else:
skip_output = (
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_spatial_outputs is not None
else None
)
if skip_output is not None:
if timestep == skip_range[-1]:
# TODO: save memory
if is_temporal:
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
else:
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
else:
raise ValueError(
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
)
return skip_output
def get_spatial_mlp_outputs(self):
return self.config.mlp_spatial_outputs
def get_temporal_mlp_outputs(self):
return self.config.mlp_temporal_outputs
def set_pab_manager(config: PABConfig):
global PAB_MANAGER
PAB_MANAGER = PABManager(config)
def enable_pab():
if PAB_MANAGER is None:
return False
return (
PAB_MANAGER.config.cross_broadcast
or PAB_MANAGER.config.spatial_broadcast
or PAB_MANAGER.config.temporal_broadcast
)
def update_steps(steps: int):
if PAB_MANAGER is not None:
PAB_MANAGER.config.steps = steps
def if_broadcast_cross(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_cross(timestep, count)
def if_broadcast_temporal(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not enable_pab():
return False, count
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)

View File

@ -1,44 +0,0 @@
import inspect
from abc import abstractmethod
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
class VideoSysPipeline(DiffusionPipeline):
def __init__(self):
super().__init__()
@staticmethod
def set_eval_and_device(device: torch.device, *modules):
for module in modules:
module.eval()
module.to(device)
@abstractmethod
def generate(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
"""
In diffusers, it is a convention to call the pipeline object.
But in VideoSys, we will use the generate method for better prompt.
This is a wrapper for the generate method to support the diffusers usage.
"""
return self.generate(*args, **kwargs)
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
# modify: remove the config module from the expected modules
expected_modules = expected_modules - {"config"}
optional_names = list(optional_parameters)
for name in optional_names:
if name in cls._optional_components:
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters

View File

@ -1,3 +0,0 @@
import torch.nn as nn
approx_gelu = lambda: nn.GELU(approximate="tanh")

View File

@ -1,71 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXDownsample3D(nn.Module):
# Todo: Wait for paper relase.
r"""
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `2`):
Stride of the convolution.
padding (`int`, defaults to `0`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x

View File

@ -1,308 +0,0 @@
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
class OpenSoraPatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs // s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
s_emb = self.mlp(s_freq)
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
@property
def dtype(self):
return next(self.parameters()).dtype
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Use for example in Lumina
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Use for example in Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)

View File

@ -1,85 +0,0 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*):
norm_elementwise_affine (`bool`, defaults to `False):
norm_eps (`bool`, defaults to `False`):
chunk_dim (`int`, defaults to `0`):
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x

View File

@ -1,67 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXUpsample3D(nn.Module):
r"""
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `1`):
Stride of the convolution.
padding (`int`, defaults to `1`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = F.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs

View File

@ -1,64 +0,0 @@
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
spatial_broadcast: bool = False,
spatial_threshold: list = None,
spatial_range: int = None,
temporal_broadcast: bool = False,
temporal_threshold: list = None,
temporal_range: int = None,
mlp_broadcast: bool = False,
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_range = temporal_range
self.mlp_broadcast = mlp_broadcast
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
self.mlp_temporal_outputs = {}
self.mlp_spatial_outputs = {}
class CogVideoXPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
temporal_broadcast: bool = False,
temporal_threshold: list = [100, 850],
temporal_range: int = 4,
cross_broadcast: bool = False,
cross_threshold: list = [100, 850],
cross_range: int = 6,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range
)