Merge branch 'refactor'

kijai 2024-11-19 19:16:39 +02:00
commit 67f2f6abb1
46 changed files with 11420 additions and 15441 deletions

.gitignore (3 changes)

@@ -7,4 +7,5 @@ master_ip
logs/
*.DS_Store
.idea
*.pt
tools/

File diff suppressed because it is too large


@@ -1,741 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import os
import json
import torch
import glob
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
#from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
#from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from ..videosys.modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from ..videosys.modules.embeddings import apply_rotary_emb
from ..videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
logger.info("Using sageattn")
except ImportError:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAILABLE = False
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# if attn.parallel_manager.sp_size > 1:
# assert (
# attn.heads % attn.parallel_manager.sp_size == 0
# ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
# attn_heads = attn.heads // attn.parallel_manager.sp_size
# query, key, value = map(
# lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
# [query, key, value],
# )
attn_heads = attn.heads
inner_dim = key.shape[-1]
head_dim = inner_dim // attn_heads
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
#if attn.parallel_manager.sp_size > 1:
# hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
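# Sketch: the processor is attached to a diffusers `Attention` module at construction time
# (as done in CogVideoXBlock.__init__ below), e.g.
#   attn = Attention(
#       query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads,
#       qk_norm="layer_norm", eps=1e-6, bias=attention_bias, out_bias=attention_out_bias,
#       processor=CogVideoXAttnProcessor2_0(),
#   )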
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
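# Shape walk-through with the default config (patch_size=2, in_channels=16, embed_dim=1920)
# and 13 latent frames at 60x90 latent resolution:
#   image_embeds: (B, 13, 16, 60, 90) -> proj -> (B*13, 1920, 30, 45)
#   -> flatten/transpose -> (B, 13 * 30 * 45, 1920) = (B, 17550, 1920)
#   text_embeds:  (B, 226, 4096) -> text_proj -> (B, 226, 1920)
#   concatenated: (B, 226 + 17550, 1920) = (B, 17776, 1920)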
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool`, defaults to `True`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
block_idx: int = 0,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# parallel
#self.attn1.parallel_manager = None
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# pab
self.attn_count = 0
self.last_attn = None
self.block_idx = block_idx
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep=None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if enable_pab():
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processes 13 latent frames at once in its default and recommended settings,
but it cannot be changed to the correct value without breaking backwards compatibility. To create a
transformer with K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
A worked example follows this docstring.
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
add_noise_in_inpaint_model: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
self.post_patch_height = post_patch_height
self.post_patch_width = post_patch_width
self.post_time_compression_frames = post_time_compression_frames
self.patch_size = patch_size
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 2)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
pos_embeds = self.pos_embedding
emb_size = hidden_states.size()[-1]
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
timestep=timestep,
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
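# Unpatchify shape walk-through with p=2, 60x90 latents and out_channels=16:
# proj_out gives (B, F*30*45, 2*2*16); reshape -> (B, F, 30, 45, 16, 2, 2);
# permute -> (B, F, 16, 30, 2, 45, 2); the two flattens -> (B, F, 16, 60, 90).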
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
else:
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "size doesn't match, skipping")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
return model
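# Usage sketch (hypothetical path): load weights from a local checkpoint directory that
# contains config.json plus diffusers' WEIGHTS_NAME (.bin) or *.safetensors shards.
#   transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
#       "path/to/CogVideoX-Fun", subfolder="transformer", transformer_additional_kwargs={}
#   )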


@@ -1,974 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
from ..videosys.core.pipeline import VideoSysPipeline
from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from ..videosys.core.pab_mgr import set_pab_manager
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogVideoX_Fun_Pipeline
>>> from diffusers.utils import export_to_video
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
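# Worked example: for a 480x720 video with patch_size=2 and a spatial VAE factor of 8,
# the grid is (30, 45) and the base grid is 45x30, so
#   get_resize_crop_region_for_grid((30, 45), 45, 30) == ((0, 0), (30, 45)),
# i.e. at the base resolution the crop region spans the whole grid.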
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
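# Usage sketch (assumes the scheduler's `set_timesteps` accepts the corresponding argument):
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device=device)
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, device=device, timesteps=[999, 749, 499, 249])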
@dataclass
class CogVideoX_Fun_PipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
videos (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
videos: torch.Tensor
class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "vae->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
pab_config = None
):
super().__init__()
self.register_modules(
vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
if pab_config is not None:
set_pab_manager(pab_config)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
latents=None, freenoise=True, context_size=None, context_overlap=None
):
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
if freenoise:
print("Applying FreeNoise")
# code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
video_length = num_frames // 4
delta = context_size - context_overlap
for start_idx in range(0, video_length-context_size, delta):
# start_idx corresponds to the beginning of a context window
# goal: place shuffled in the delta region right after the end of the context window
# if space after context window is not enough to place the noise, adjust and finish
place_idx = start_idx + context_size
# if place_idx is outside the valid indexes, we are already finished
if place_idx >= video_length:
break
end_idx = place_idx - 1
#print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
# if there is not enough room to copy delta amount of indexes, copy limited amount and finish
if end_idx + delta >= video_length:
final_delta = video_length - place_idx
# generate list of indexes in final delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
# apply shuffled indexes
noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
break
# otherwise, do normal behavior
# generate list of indexes in delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(delta, generator=generator)]
# apply shuffled indexes
#print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
if latents is None:
latents = noise.to(device)
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
latent_timestep = timesteps[:1]
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
frames_needed = noise.shape[1]
current_frames = latents.shape[1]
if frames_needed > current_frames:
repeat_factor = frames_needed // current_frames
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
latents = torch.cat((latents, additional_frame), dim=1)
elif frames_needed < current_frames:
latents = latents[:, :frames_needed, :, :, :]
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
return latents, timesteps, noise
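# FreeNoise worked example (hypothetical sizes): with 24 latent frames, context_size=12 and
# context_overlap=4, delta=8 and start_idx iterates over {0, 8}:
#   start_idx=0 -> frames 12..19 receive a shuffled copy of frames 0..7
#   start_idx=8 -> only 4 frames remain, so frames 20..23 receive a shuffled copy of frames 8..11
# so noise in each later context window is a permutation of noise seen in earlier windows.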
def prepare_control_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
if mask is not None:
mask = mask.to(device=device, dtype=self.vae.dtype)
bs = 1
new_mask = []
for i in range(0, mask.shape[0], bs):
mask_bs = mask[i : i + bs]
mask_bs = self.vae.encode(mask_bs)[0]
mask_bs = mask_bs.mode()
new_mask.append(mask_bs)
mask = torch.cat(new_mask, dim = 0)
mask = mask * self.vae.config.scaling_factor
if masked_image is not None:
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
bs = 1
new_mask_pixel_values = []
for i in range(0, masked_image.shape[0], bs):
mask_pixel_values_bs = masked_image[i : i + bs]
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
mask_pixel_values_bs = mask_pixel_values_bs.mode()
new_mask_pixel_values.append(mask_pixel_values_bs)
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
else:
masked_image_latents = None
return mask, masked_image_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
frames = (frames / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
frames = frames.cpu().float().numpy()
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _gaussian_weights(self, t_tile_length, t_batch_size):
from numpy import pi, exp, sqrt
var = 0.01
midpoint = (t_tile_length - 1) / 2 # -1 because the index goes from 0 to t_tile_length - 1
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
weights = torch.tensor(t_probs)
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
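# Sketch: with t_batch_size=1 the returned tensor has shape [1, t_tile_length, 1, 1, 1] and
# broadcasts over [B, F, C, H, W] latents; the bell-shaped profile peaks at the tile midpoint,
# down-weighting tile borders so overlapping temporal tiles blend smoothly when accumulated.
#   w = self._gaussian_weights(t_tile_length=4, t_batch_size=1)  # shape [1, 4, 1, 1, 1]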
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
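# Usage sketch: fuse QKV projections before inference and restore them afterwards.
#   pipe.fuse_qkv_projections()
#   ...  # run pipe(...)
#   pipe.unfuse_qkv_projections()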
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
context_frames: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
if start_frame is not None or context_frames is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
if context_frames is not None:
freqs_cos = freqs_cos[context_frames]
freqs_sin = freqs_sin[context_frames]
else:
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
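# Worked example: for height=480, width=720 (vae_scale_factor_spatial=8, patch_size=2) the grid is
# 30x45 and the base size is 45x30; with use_real=True, get_3d_rotary_pos_embed is expected to return
# cos/sin tensors of roughly (num_frames * 30 * 45, attention_head_dim), e.g. (17550, 64) for 13 frames.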
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start
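# Worked example: num_inference_steps=50 and strength=0.6 give init_timestep=30 and t_start=20,
# so (for a scheduler with order 1) the last 30 entries of scheduler.timesteps are used and
# num_inference_steps is reported as 30.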
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
video: Union[torch.FloatTensor] = None,
control_video: Union[torch.FloatTensor] = None,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
denoise_strength: float = 1.0,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "numpy",
return_dict: bool = False,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
comfyui_progressbar: bool = False,
control_strength: float = 1.0,
control_start_percent: float = 0.0,
control_end_percent: float = 1.0,
scheduler_name: str = "DPM",
context_schedule: Optional[str] = None,
context_frames: Optional[int] = None,
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
tora: Optional[dict] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to `480`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `720`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `49`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to `6`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"numpy"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# if num_frames > 49:
# raise ValueError(
# "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
# )
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
if comfyui_progressbar:
from comfy.utils import ProgressBar
pbar = ProgressBar(num_inference_steps + 2)
# 5. Prepare latents.
latent_channels = self.vae.config.latent_channels
latents, timesteps, noise = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
self.vae.dtype,
device,
generator,
timesteps,
denoise_strength,
num_inference_steps,
latents,
context_size=context_frames,
context_overlap=context_overlap,
freenoise=freenoise,
)
if comfyui_progressbar:
pbar.update(1)
control_video_latents_input = (
torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video
)
control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
control_latents = control_latents * control_strength
if comfyui_progressbar:
pbar.update(1)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.5. Temporal tiling prep
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
use_context_schedule = False
print("Temporal tiling enabled")
elif context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1]
latents_all_list = []
# =====================================================
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=control_latents_tile,
)[0]
noise_pred = noise_pred.float()
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
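                    # Blend the per-tile results: each tile is weighted by the temporal Gaussian window
                    # from self._gaussian_weights, and the running `contributors` sum normalizes frames
                    # covered by more than one tile, smoothing seams at tile boundaries.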
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
pbar.update(1)
# ==========================================
elif use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Calculate the current step percentage
current_step_percentage = i / num_inference_steps
# Determine if control_latents should be applied
apply_control = control_start_percent <= current_step_percentage <= control_end_percent
current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
context_queue = list(context(
i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
))
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
partial_control_latents = current_control_latents[:, c, :, :, :]
# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=partial_control_latents,
)[0]
counter[:, c, :, :, :] += 1
noise_pred = noise_pred.float()
noise_pred /= counter
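                    # Overlapping context windows are averaged: predictions accumulate into `noise_pred`
                    # and `counter` tracks how many windows covered each latent frame.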
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Calculate the current step percentage
current_step_percentage = i / num_inference_steps
# Determine if control_latents should be applied
apply_control = control_start_percent <= current_step_percentage <= control_end_percent
current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=current_control_latents,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)
# if output_type == "numpy":
# video = self.decode_latents(latents)
# elif not output_type == "latent":
# video = self.decode_latents(latents)
# video = self.video_processor.postprocess_video(video=video, output_type=output_type)
# else:
# video = latents
# Offload all models
self.maybe_free_model_hooks()
# if not return_dict:
# video = torch.from_numpy(video)
return latents
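        # Decoding sketch (not part of this pipeline's return path; it mirrors the commented-out
        # `decode_latents` path above and is illustrative only):
        #   frames = pipeline.decode_latents(latents)
        #   frames = pipeline.video_processor.postprocess_video(video=frames, output_type="np")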

File diff suppressed because it is too large

View File

@ -1,873 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import os
import json
import torch
import glob
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor#, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from einops import rearrange
try:
from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True
logger.info("Using sageattn")
except:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAVILABLE = False
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
B, C, H, W = tensor.size()
radius = min(H, W) // 5
Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
center_x, center_y = W // 2, H // 2
mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
high_freq_mask = ~low_freq_mask
low_freq_fft = tensor_fft_shifted * low_freq_mask
high_freq_fft = tensor_fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
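# Illustrative sketch (assumed shapes, not called anywhere in this file): the two masks returned
# by `fft` partition the shifted spectrum, so summing them and inverting recovers the input. The
# FasterCache path further down caches the *difference* between the conditional and unconditional
# branches separately for the low- and high-frequency parts.
def _fft_split_example():
    frames = torch.randn(2, 16, 60, 90)                      # (B*T, C, H, W) latent frames
    low_fft, high_fft = fft(frames)                          # complementary masked spectra
    restored = torch.fft.ifft2(torch.fft.ifftshift(low_fft + high_fft)).real
    return torch.allclose(restored, frames, atol=1e-4)       # True up to FFT round-off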
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
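# Wiring sketch (mirrors how CogVideoXBlock below constructs its attention; the values shown are
# the 2B defaults and are illustrative only):
#   attn = Attention(
#       query_dim=1920, heads=30, dim_head=64, qk_norm="layer_norm", eps=1e-6,
#       bias=True, out_bias=True, processor=CogVideoXAttnProcessor2_0(),
#   )
# The processor concatenates the text and video streams, runs joint self-attention with RoPE
# applied to the video tokens only, and returns the two streams split apart again.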
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
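# Shape sketch for the patch embed above. The sizes are assumptions matching the default config
# documented further down (sample_height=60, sample_width=90, 13 latent frames), not values read
# from a checkpoint.
def _patch_embed_shape_example():
    embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=1920, text_embed_dim=4096)
    text_embeds = torch.randn(1, 226, 4096)        # (batch, max_text_seq_length, text_embed_dim)
    image_embeds = torch.randn(1, 13, 16, 60, 90)  # (batch, frames, channels, height, width)
    out = embed(text_embeds, image_embeds)
    return out.shape                               # torch.Size([1, 226 + 13 * 30 * 45, 1920])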
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
#fastercache
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
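# Minimal forward sketch for a single block. The dimensions are assumptions taken from the default
# transformer config below (30 heads x 64 dims = 1920 channels, 512-dim time embedding, 226 text
# tokens); the video token count is arbitrary.
def _block_example():
    block = CogVideoXBlock(dim=1920, num_attention_heads=30, attention_head_dim=64, time_embed_dim=512)
    hidden_states = torch.randn(1, 1350, 1920)         # video tokens (one 30x45 latent frame)
    encoder_hidden_states = torch.randn(1, 226, 1920)  # text tokens, already projected to 1920
    temb = torch.randn(1, 512)                         # timestep embedding
    return block(hidden_states, encoder_hidden_states, temb)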
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
add_noise_in_inpaint_model: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
self.post_patch_height = post_patch_height
self.post_patch_width = post_patch_width
self.post_time_compression_frames = post_time_compression_frames
self.patch_size = patch_size
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
self.fuser_list = None
self.use_fastercache = False
self.fastercache_counter = 0
self.fastercache_start_step = 15
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
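        # Usage sketch: either replace every processor with a single instance, e.g.
        #   model.set_attn_processor(CogVideoXAttnProcessor2_0())
        # or pass a dict keyed like `model.attn_processors` to target specific layers.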
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
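        # Usage sketch (experimental API, illustrative only): fuse before inference and restore
        # the original processors afterwards.
        #   model.fuse_qkv_projections()
        #   ...                                    # run denoising
        #   model.unfuse_qkv_projections()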
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_features: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 2)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
pos_embeds = self.pos_embedding
emb_size = hidden_states.size()[-1]
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
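            # The sincos table registered above is built for the configured sample size; the trilinear
            # interpolation rescales its spatial grid so latents at other resolutions can reuse the same
            # buffer before it is added to the patch embeddings below.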
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
if self.use_fastercache:
self.fastercache_counter+=1
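        # FasterCache (descriptive note): once the counter passes start_step + 3, four out of every
        # five steps push only the first batch element through the transformer blocks; the other
        # classifier-free-guidance branch is reconstructed afterwards by adding the cached low/high
        # frequency deltas (self.delta_lf / self.delta_hf) to that output's 2D spectrum and inverting
        # the FFT.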
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states[:1],
encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1],
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb[:1])
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
#lf_step = 40
#hf_step = 30
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
self.delta_hf = self.delta_hf * 1.1
new_hf_uc = self.delta_hf + hf_c
new_lf_uc = self.delta_lf + lf_c
combine_uc = new_lf_uc + new_hf_uc
combined_fft = torch.fft.ifftshift(combine_uc)
recovered_uncond = torch.fft.ifft2(combined_fft).real
recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
output = torch.cat([output, recovered_uncond])
else:
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond)
lf_uc, hf_uc = fft(uncond)
self.delta_lf = lf_uc - lf_c
self.delta_hf = hf_uc - hf_c
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
else:
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
                print(key, "size mismatch, skipping")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
return model
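# Usage sketch (the path, subfolder and extra kwargs are placeholders, not a verified checkpoint
# layout):
#
#   transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
#       "checkpoints/CogVideoX-Fun", subfolder="transformer",
#       transformer_additional_kwargs={"add_noise_in_inpaint_model": True},
#   )
#
# Keys whose shapes do not match the instantiated config are skipped (see the size check above),
# so the printed missing/unexpected key counts are the main sanity signal.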

View File

@ -1,26 +1,6 @@
import os
import gc
import numpy as np
import torch
from PIL import Image
# Copyright (c) OpenMMLab. All rights reserved.
def tensor2pil(image):
return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
def numpy2pil(image):
return Image.fromarray(np.clip(255. * image, 0, 255).astype(np.uint8))
def to_pil(image):
if isinstance(image, Image.Image):
return image
if isinstance(image, torch.Tensor):
return tensor2pil(image)
if isinstance(image, np.ndarray):
return numpy2pil(image)
raise ValueError(f"Cannot convert {type(image)} to PIL.Image")
ASPECT_RATIO_512 = {
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
@ -54,126 +34,10 @@ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_5
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
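# Usage sketch: `get_closest_ratio(height, width)` picks the (height, width) bucket whose
# aspect-ratio key is numerically closest to height / width and also returns that key as a float,
# e.g. a 1024x4096 input maps to the '0.25' bucket above.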
def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
target_pixels = int(base_resolution) * int(base_resolution)
original_width, original_height = Image.open(image).size
ratio = (target_pixels / (original_width * original_height)) ** 0.5
width_slider = round(original_width * ratio)
height_slider = round(original_height * ratio)
return height_slider, width_slider
def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
if validation_image_start is not None and validation_image_end is not None:
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
image_start = clip_image = Image.open(validation_image_start).convert("RGB")
image_start = image_start.resize([sample_size[1], sample_size[0]])
clip_image = clip_image.resize([sample_size[1], sample_size[0]])
else:
image_start = clip_image = validation_image_start
image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
if type(validation_image_end) is str and os.path.isfile(validation_image_end):
image_end = Image.open(validation_image_end).convert("RGB")
image_end = image_end.resize([sample_size[1], sample_size[0]])
else:
image_end = validation_image_end
image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
if type(image_start) is list:
clip_image = clip_image[0]
start_video = torch.cat(
[torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
dim=2
)
input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
input_video[:, :, :len(image_start)] = start_video
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, len(image_start):] = 255
else:
input_video = torch.tile(
torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
[1, 1, video_length, 1, 1]
)
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, 1:] = 255
if type(image_end) is list:
image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
end_video = torch.cat(
[torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
dim=2
)
input_video[:, :, -len(end_video):] = end_video
input_video_mask[:, :, -len(image_end):] = 0
else:
image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
input_video_mask[:, :, -1:] = 0
input_video = input_video / 255
elif validation_image_start is not None:
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
image_start = clip_image = Image.open(validation_image_start).convert("RGB")
image_start = image_start.resize([sample_size[1], sample_size[0]])
clip_image = clip_image.resize([sample_size[1], sample_size[0]])
else:
image_start = clip_image = validation_image_start
image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
image_end = None
if type(image_start) is list:
clip_image = clip_image[0]
start_video = torch.cat(
[torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
dim=2
)
input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
input_video[:, :, :len(image_start)] = start_video
input_video = input_video / 255
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, len(image_start):] = 255
else:
input_video = torch.tile(
torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
[1, 1, video_length, 1, 1]
) / 255
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, 1:, ] = 255
else:
image_start = None
image_end = None
input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
clip_image = None
del image_start
del image_end
gc.collect()
return input_video, input_video_mask, clip_image
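# Usage sketch (paths are placeholders): sample_size is (height, width), and the returned mask uses
# 255 for frames the model should generate and 0 for frames pinned to the reference images.
#
#   input_video, input_video_mask, clip_image = get_image_to_video_latent(
#       "start_frame.png", "end_frame.png", video_length=49, sample_size=(480, 720)
#   )
#   # input_video: (1, 3, 49, 480, 720) in [0, 1]; input_video_mask: (1, 1, 49, 480, 720)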
def get_video_to_video_latent(input_video_path, video_length, sample_size, validation_video_mask=None):
input_video = input_video_path
input_video = torch.from_numpy(np.array(input_video))[:video_length]
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
if validation_video_mask is not None:
validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
else:
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, :] = 255
return input_video, input_video_mask, None
return height_slider, width_slider

View File

@ -27,23 +27,28 @@ from diffusers.utils import logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.embeddings import apply_rotary_emb
from .embeddings import CogVideoXPatchEmbed
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
logger.info("Using sageattn")
except:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAILABLE = False
@torch.compiler.disable()
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
@ -70,7 +75,7 @@ class CogVideoXAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
@ -78,6 +83,7 @@ class CogVideoXAttnProcessor2_0:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mode: Optional[str] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -91,9 +97,14 @@ class CogVideoXAttnProcessor2_0:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        if attention_mode not in ("fused_sdpa", "fused_sageattn"):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
else:
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@ -109,18 +120,18 @@ class CogVideoXAttnProcessor2_0:
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
#if torch.isinf(hidden_states).any():
# raise ValueError(f"hidden_states after dot product has inf")
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@ -134,80 +145,7 @@ class CogVideoXAttnProcessor2_0:
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
#region Blocks
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
@ -266,6 +204,7 @@ class CogVideoXBlock(nn.Module):
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
@ -300,15 +239,21 @@ class CogVideoXBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
block_use_fastercache=False,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
attention_mode="sdpa",
) -> torch.Tensor:
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
#print("norm_hidden_states in block: ", norm_hidden_states.shape) #torch.Size([2, 3200, 3072])
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
@ -318,36 +263,46 @@ class CogVideoXBlock(nn.Module):
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
#fastercache
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
#region fastercache
if block_use_fastercache:
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
@ -361,7 +316,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
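# --- Illustrative note (not part of the upstream file) ---
# The FasterCache branch above reuses attention outputs: once fastercache_counter
# exceeds fastercache_start_step + 3, steps where counter % 3 != 0 skip self.attn1
# and linearly extrapolate from the two cached results instead. A minimal sketch of
# that extrapolation, with `prev` the older and `last` the newer cached output:
def _example_fastercache_extrapolation(prev: torch.Tensor, last: torch.Tensor) -> torch.Tensor:
    # approx = last + 0.3 * (last - prev), i.e. a small first-order step forward in time
    return last + (last - prev) * 0.3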
#region Transformer
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@ -428,6 +383,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
ofs_embed_dim: Optional[int] = None,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
@ -436,6 +392,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
patch_size_t: int = None,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
@ -446,6 +403,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -460,10 +418,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
patch_size_t=patch_size_t,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
@ -480,6 +439,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
self.ofs_proj = None
self.ofs_embedding = None
if ofs_embed_dim:
self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
@ -507,7 +473,14 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
if patch_size_t is None:
# For CogVideox 1.0
output_dim = patch_size * patch_size * out_channels
else:
# For CogVideoX 1.5
output_dim = patch_size * patch_size * patch_size_t * out_channels
self.proj_out = nn.Linear(inner_dim, output_dim)
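# Worked example for the projection width above (assuming out_channels=16 and the
# default patch_size=2, with patch_size_t=2 for the 1.5 checkpoints):
#   CogVideoX 1.0: output_dim = 2 * 2 * 16 = 64
#   CogVideoX 1.5: output_dim = 2 * 2 * 2 * 16 = 128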
self.gradient_checkpointing = False
@ -518,6 +491,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
self.attention_mode = "sdpa"
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@ -582,45 +558,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
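# --- Illustrative usage note (not part of the upstream file) ---
# A minimal sketch of the experimental fuse/unfuse API defined above, assuming
# `transformer` is an already-loaded CogVideoXTransformer3DModel:
#
#   transformer.fuse_qkv_projections()    # switch to FusedCogVideoXAttnProcessor2_0
#   ...  # run inference with fused projections
#   transformer.unfuse_qkv_projections()  # restore the original attention processors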
def forward(
self,
@ -628,6 +565,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
@ -635,7 +573,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
@ -644,40 +582,57 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.ofs_embedding is not None: #1.5 I2V
ofs_emb = self.ofs_proj(ofs)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_emb = self.ofs_embedding(ofs_emb)
emb = emb + ofs_emb
# 2. Patch embedding
p = self.config.patch_size
p_t = self.config.patch_size_t
#print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
#print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
#print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
if self.use_fastercache:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states[:1],
encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1],
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states[:1],
encoder_hidden_states=encoder_hidden_states[:1],
temb=emb[:1],
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@ -696,9 +651,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if p_t is None:
output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
1, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -727,19 +688,26 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
)
#has_nan = torch.isnan(hidden_states).any()
#if has_nan:
# raise ValueError(f"block output hidden_states has nan: {has_nan}")
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
#controlnet
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@ -758,9 +726,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
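# Worked shape example (illustrative, not part of the upstream file) for the
# patch_size_t branch above, assuming latent input [B, 4, 16, 60, 90] with p = 2
# and p_t = 2 (CogVideoX 1.5 defaults):
#   reshape  -> [B, (4 + 2 - 1) // 2 = 2, 30, 45, 16, 2, 2, 2]
#   permute  -> [B, 2, 2, 16, 30, 2, 45, 2]
#   flattens -> [B, 4, 16, 60, 90]
# i.e. the temporal patching introduced for 1.5 is undone together with the spatial patching.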
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape

226
embeddings.py Normal file
View File

@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Union, Optional
from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
patch_size_t: Optional[int] = None,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
use_learned_positional_embeddings: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
if patch_size_t is None:
# CogVideoX 1.0 checkpoints
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
else:
# CogVideoX 1.5 checkpoints
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings or use_learned_positional_embeddings:
persistent = use_learned_positional_embeddings
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
return joint_pos_embedding
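# Worked example with the default configuration above (sample_height=60, sample_width=90,
# sample_frames=49, patch_size=2, temporal_compression_ratio=4, max_text_seq_length=226):
#   post_patch_height            = 60 // 2 = 30
#   post_patch_width             = 90 // 2 = 45
#   post_time_compression_frames = (49 - 1) // 4 + 1 = 13
#   num_patches                  = 30 * 45 * 13 = 17550
# so the joint positional embedding buffer has shape [1, 226 + 17550, embed_dim].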
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch_size, num_frames, channels, height, width = image_embeds.shape
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
else:
p = self.patch_size
p_t = self.patch_size_t
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding
embeds = embeds + pos_embedding
return embeds
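# --- Illustrative usage sketch (not part of the upstream file) ---
# Hypothetical shapes for the 1.0-style (patch_size_t=None) Conv2d patching path,
# with positional embeddings disabled to keep the example minimal:
def _example_patch_embed_forward():
    embedder = CogVideoXPatchEmbed(
        patch_size=2,
        in_channels=16,
        embed_dim=1920,
        text_embed_dim=4096,
        use_positional_embeddings=False,
        use_learned_positional_embeddings=False,
    )
    text_embeds = torch.randn(1, 226, 4096)        # [B, max_text_seq_length, text_embed_dim]
    image_embeds = torch.randn(1, 13, 16, 60, 90)  # [B, num_frames, C, H, W] latent grid
    embeds = embedder(text_embeds, image_embeds)
    return embeds.shape  # torch.Size([1, 226 + 13 * 30 * 45, 1920]) == [1, 17776, 1920]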
def get_3d_rotary_pos_embed(
embed_dim,
crops_coords,
grid_size,
temporal_size,
theta: int = 10000,
use_real: bool = True,
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
grid_type (`str`):
Whether to use "linspace" or "slice" to compute grids.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
if grid_type == "linspace":
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
elif grid_type == "slice":
max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32)
grid_w = np.arange(max_w, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
else:
raise ValueError("Invalid value passed for `grid_type`.")
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
) # temporal_size, grid_size_h, grid_size_w, dim_w
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
if grid_type == "slice":
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
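# --- Illustrative usage sketch (not part of the upstream file) ---
# Hypothetical sizes: RoPE frequencies for a 13-frame, 30x45 latent token grid with
# attention_head_dim=64 (dim_t=16, dim_h=dim_w=24, which sum back to 64):
def _example_3d_rope():
    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64,
        crops_coords=((0, 0), (30, 45)),
        grid_size=(30, 45),
        temporal_size=13,
    )
    # cos and sin each have 13 * 30 * 45 = 17550 rows, one per video token
    return cos.shape, sin.shape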

View File

@ -1,561 +0,0 @@
{
"last_node_id": 34,
"last_link_id": 61,
"nodes": [
{
"id": 33,
"type": "GetImageSizeAndCount",
"pos": {
"0": 1176,
"1": 122
},
"size": {
"0": 210,
"1": 86
},
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 59
}
],
"outputs": [
{
"name": "image",
"type": "IMAGE",
"links": [
60
],
"slot_index": 0,
"shape": 3
},
{
"name": "720 width",
"type": "INT",
"links": null,
"shape": 3
},
{
"name": "480 height",
"type": "INT",
"links": null,
"shape": 3
},
{
"name": "104 count",
"type": "INT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "GetImageSizeAndCount"
},
"widgets_values": []
},
{
"id": 30,
"type": "CogVideoTextEncode",
"pos": {
"0": 500,
"1": 308
},
"size": [
474.8035864085422,
211.10369504535595
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 54
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
55
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature\nacoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters\nthrough the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The\nbackground includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical\nperformance.",
1,
true
]
},
{
"id": 31,
"type": "CogVideoTextEncode",
"pos": {
"0": 508,
"1": 576
},
"size": {
"0": 463.01251220703125,
"1": 124
},
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 56
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
57
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"",
1,
true
]
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -37,
"1": 443
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
54,
56
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 11,
"type": "CogVideoDecode",
"pos": {
"0": 1045,
"1": 776
},
"size": {
"0": 295.70111083984375,
"1": 198
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 37
},
{
"name": "samples",
"type": "LATENT",
"link": 38
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
59
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoDecode"
},
"widgets_values": [
true,
96,
96,
0.083,
0.083,
true
]
},
{
"id": 1,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 652,
"1": 43
},
"size": {
"0": 315,
"1": 194
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "pab_config",
"type": "PAB_CONFIG",
"link": null
},
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null
},
{
"name": "lora",
"type": "COGLORA",
"link": null
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
36
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"THUDM/CogVideoX-2b",
"fp16",
"enabled",
"disabled",
false
]
},
{
"id": 32,
"type": "VHS_VideoCombine",
"pos": {
"0": 1439,
"1": 122
},
"size": [
563.3333740234375,
686.2222493489583
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 60,
"slot_index": 0
},
{
"name": "audio",
"type": "VHS_AUDIO",
"link": null
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "CogVideo2B_long",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"pingpong": false,
"save_output": false,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "CogVideo2B_long_00005.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 8
}
}
}
},
{
"id": 34,
"type": "CogVideoContextOptions",
"pos": {
"0": 1053,
"1": -84
},
"size": {
"0": 315,
"1": 154
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "context_options",
"type": "COGCONTEXT",
"links": [
61
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoContextOptions"
},
"widgets_values": [
"uniform_standard",
52,
4,
8,
true
]
},
{
"id": 22,
"type": "CogVideoSampler",
"pos": {
"0": 1041,
"1": 342
},
"size": {
"0": 315,
"1": 382
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 36
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 55,
"slot_index": 1
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 57
},
{
"name": "samples",
"type": "LATENT",
"link": null
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": null
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": 61
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
37
],
"shape": 3
},
{
"name": "samples",
"type": "LATENT",
"links": [
38
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
480,
720,
104,
32,
6,
42,
"fixed",
"CogVideoXDDIM",
1
]
}
],
"links": [
[
36,
1,
0,
22,
0,
"COGVIDEOPIPE"
],
[
37,
22,
0,
11,
0,
"COGVIDEOPIPE"
],
[
38,
22,
1,
11,
1,
"LATENT"
],
[
54,
20,
0,
30,
0,
"CLIP"
],
[
55,
30,
0,
22,
1,
"CONDITIONING"
],
[
56,
20,
0,
31,
0,
"CLIP"
],
[
57,
31,
0,
22,
2,
"CONDITIONING"
],
[
59,
11,
0,
33,
0,
"IMAGE"
],
[
60,
33,
0,
32,
0,
"IMAGE"
],
[
61,
34,
0,
22,
5,
"COGCONTEXT"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.8390545288825444,
"offset": [
-14.198557467892236,
144.90015432747748
]
}
},
"version": 0.4
}

View File

@ -1,42 +1,7 @@
{
"last_node_id": 58,
"last_link_id": 129,
"last_node_id": 63,
"last_link_id": 149,
"nodes": [
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -26,
"1": 400
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
54,
56
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 31,
"type": "CogVideoTextEncode",
@ -46,16 +11,16 @@
},
"size": {
"0": 463.01251220703125,
"1": 124
"1": 144
},
"flags": {},
"order": 4,
"order": 6,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 56
"link": 149
}
],
"outputs": [
@ -63,10 +28,15 @@
"name": "conditioning",
"type": "CONDITIONING",
"links": [
123
146
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": null
}
],
"properties": {
@ -78,6 +48,208 @@
true
]
},
{
"id": 63,
"type": "CogVideoSampler",
"pos": {
"0": 1142,
"1": 74
},
"size": [
330,
574
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "COGVIDEOMODEL",
"link": 144
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 145
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 146
},
{
"name": "samples",
"type": "LATENT",
"link": null,
"shape": 7
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": 147,
"shape": 7
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null,
"shape": 7
},
{
"name": "controlnet",
"type": "COGVIDECONTROLNET",
"link": null,
"shape": 7
},
{
"name": "tora_trajectory",
"type": "TORAFEATURES",
"link": null,
"shape": 7
},
{
"name": "fastercache",
"type": "FASTERCACHEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"links": [
148
]
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
49,
25,
6,
0,
"fixed",
"CogVideoXDDIM",
1
]
},
{
"id": 62,
"type": "CogVideoImageEncode",
"pos": {
"0": 1149,
"1": 711
},
"size": {
"0": 315,
"1": 122
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "vae",
"type": "VAE",
"link": 141
},
{
"name": "start_image",
"type": "IMAGE",
"link": 142
},
{
"name": "end_image",
"type": "IMAGE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"links": [
147
]
}
],
"properties": {
"Node name for S&R": "CogVideoImageEncode"
},
"widgets_values": [
false,
0
]
},
{
"id": 59,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 622,
"1": -25
},
"size": {
"0": 315,
"1": 218
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null,
"shape": 7
},
{
"name": "lora",
"type": "COGLORA",
"link": null,
"shape": 7
},
{
"name": "compile_args",
"type": "COMPILEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "model",
"type": "COGVIDEOMODEL",
"links": [
144
]
},
{
"name": "vae",
"type": "VAE",
"links": [
132,
141
],
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"THUDM/CogVideoX-5b-I2V",
"bf16",
"disabled",
false,
"sdpa",
"main_device"
]
},
{
"id": 30,
"type": "CogVideoTextEncode",
@ -90,7 +262,7 @@
"1": 168.08047485351562
},
"flags": {},
"order": 3,
"order": 4,
"mode": 0,
"inputs": [
{
@ -104,10 +276,18 @@
"name": "conditioning",
"type": "CONDITIONING",
"links": [
122
145
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": [
149
],
"slot_index": 1
}
],
"properties": {
@ -116,22 +296,22 @@
"widgets_values": [
"a majestic stag is grazing in an enhanced forest, basking in the setting sun filtered by the trees",
1,
true
false
]
},
{
"id": 37,
"type": "ImageResizeKJ",
"pos": {
"0": 809,
"1": 684
"0": 784,
"1": 731
},
"size": {
"0": 315,
"1": 266
},
"flags": {},
"order": 5,
"order": 3,
"mode": 0,
"inputs": [
{
@ -142,7 +322,8 @@
{
"name": "get_image_size",
"type": "IMAGE",
"link": null
"link": null,
"shape": 7
},
{
"name": "width_input",
@ -166,7 +347,7 @@
"name": "IMAGE",
"type": "IMAGE",
"links": [
125
142
],
"slot_index": 0,
"shape": 3
@ -199,64 +380,88 @@
]
},
{
"id": 58,
"type": "CogVideoImageEncode",
"id": 36,
"type": "LoadImage",
"pos": {
"0": 1156,
"1": 650
"0": 335,
"1": 731
},
"size": {
"0": 315,
"1": 122
"0": 402.06353759765625,
"1": 396.6225891113281
},
"flags": {},
"order": 6,
"order": 1,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 124
},
{
"name": "image",
"type": "IMAGE",
"link": 125
},
{
"name": "mask",
"type": "MASK",
"link": null
}
],
"inputs": [],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"name": "IMAGE",
"type": "IMAGE",
"links": [
129
71
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"sd3stag.png",
"image"
]
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -2,
"1": 304
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
54
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoImageEncode"
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
16,
true
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 56,
"id": 60,
"type": "CogVideoDecode",
"pos": {
"0": 1581,
"1": 148
"0": 1523,
"1": -6
},
"size": {
"0": 300.396484375,
"0": 315,
"1": 198
},
"flags": {},
@ -264,14 +469,14 @@
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 128
"name": "vae",
"type": "VAE",
"link": 132
},
{
"name": "samples",
"type": "LATENT",
"link": 127
"link": 148
}
],
"outputs": [
@ -279,17 +484,15 @@
"name": "images",
"type": "IMAGE",
"links": [
118
],
"slot_index": 0,
"shape": 3
134
]
}
],
"properties": {
"Node name for S&R": "CogVideoDecode"
},
"widgets_values": [
false,
true,
240,
360,
0.2,
@ -301,8 +504,8 @@
"id": 44,
"type": "VHS_VideoCombine",
"pos": {
"0": 1927,
"1": 146
"0": 1884,
"1": -6
},
"size": [
605.3909912109375,
@ -315,22 +518,25 @@
{
"name": "images",
"type": "IMAGE",
"link": 118
"link": 134
},
{
"name": "audio",
"type": "AUDIO",
"link": null
"link": null,
"shape": 7
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null
"link": null,
"shape": 7
},
{
"name": "vae",
"type": "VAE",
"link": null
"link": null,
"shape": 7
}
],
"outputs": [
@ -367,180 +573,6 @@
"muted": false
}
}
},
{
"id": 36,
"type": "LoadImage",
"pos": {
"0": 365,
"1": 685
},
"size": {
"0": 402.06353759765625,
"1": 396.6225891113281
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
71
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"sd3stag.png",
"image"
]
},
{
"id": 57,
"type": "CogVideoSampler",
"pos": {
"0": 1138,
"1": 150
},
"size": [
399.878095897654,
350
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 121
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 122
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 123
},
{
"name": "samples",
"type": "LATENT",
"link": null
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": 129
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
128
],
"slot_index": 0,
"shape": 3
},
{
"name": "samples",
"type": "LATENT",
"links": [
127
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
480,
720,
49,
20,
6,
65334758276105,
"fixed",
"CogVideoXDPMScheduler",
1
]
},
{
"id": 1,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 633,
"1": 44
},
"size": {
"0": 337.8885192871094,
"1": 194
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "pab_config",
"type": "PAB_CONFIG",
"link": null
},
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null
},
{
"name": "lora",
"type": "COGLORA",
"link": null
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
121,
124
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"THUDM/CogVideoX-5b-I2V",
"bf16",
"disabled",
"disabled",
false
]
}
],
"links": [
@ -552,14 +584,6 @@
0,
"CLIP"
],
[
56,
20,
0,
31,
0,
"CLIP"
],
[
71,
36,
@ -569,86 +593,94 @@
"IMAGE"
],
[
118,
56,
132,
59,
1,
60,
0,
"VAE"
],
[
134,
60,
0,
44,
0,
"IMAGE"
],
[
121,
141,
59,
1,
62,
0,
57,
0,
"COGVIDEOPIPE"
"VAE"
],
[
122,
30,
0,
57,
1,
"CONDITIONING"
],
[
123,
31,
0,
57,
2,
"CONDITIONING"
],
[
124,
1,
0,
58,
0,
"COGVIDEOPIPE"
],
[
125,
142,
37,
0,
58,
62,
1,
"IMAGE"
],
[
127,
57,
1,
56,
1,
"LATENT"
144,
59,
0,
63,
0,
"COGVIDEOMODEL"
],
[
128,
57,
145,
30,
0,
56,
0,
"COGVIDEOPIPE"
63,
1,
"CONDITIONING"
],
[
129,
58,
146,
31,
0,
57,
63,
2,
"CONDITIONING"
],
[
147,
62,
0,
63,
4,
"LATENT"
],
[
148,
63,
0,
60,
1,
"LATENT"
],
[
149,
30,
1,
31,
0,
"CLIP"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.6934334949442514,
"scale": 0.7627768444387059,
"offset": [
-24.154349208343916,
155.20539218330134
648.7113591814891,
185.9907078691075
]
}
},

File diff suppressed because one or more lines are too long

View File

@ -1,48 +1,7 @@
{
"last_node_id": 34,
"last_link_id": 64,
"last_node_id": 37,
"last_link_id": 72,
"nodes": [
{
"id": 31,
"type": "CogVideoTextEncode",
"pos": {
"0": 503,
"1": 521
},
"size": {
"0": 463.01251220703125,
"1": 124
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 56
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
62
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"",
1,
true
]
},
{
"id": 30,
"type": "CogVideoTextEncode",
@ -50,12 +9,12 @@
"0": 500,
"1": 308
},
"size": {
"0": 471.90142822265625,
"1": 168.08047485351562
},
"size": [
470.99399664051055,
237.5088638951354
],
"flags": {},
"order": 2,
"order": 3,
"mode": 0,
"inputs": [
{
@ -69,10 +28,18 @@
"name": "conditioning",
"type": "CONDITIONING",
"links": [
61
67
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": [
65
],
"slot_index": 1
}
],
"properties": {
@ -81,192 +48,79 @@
"widgets_values": [
"A golden retriever, sporting sleek black sunglasses, with its lengthy fur flowing in the breeze, sprints playfully across a rooftop terrace, recently refreshed by a light rain. The scene unfolds from a distance, the dog's energetic bounds growing larger as it approaches the camera, its tail wagging with unrestrained joy, while droplets of water glisten on the concrete behind it. The overcast sky provides a dramatic backdrop, emphasizing the vibrant golden coat of the canine as it dashes towards the viewer.\n\n",
1,
true
false
]
},
{
"id": 33,
"type": "VHS_VideoCombine",
"id": 31,
"type": "CogVideoTextEncode",
"pos": {
"0": 1441,
"1": 129
"0": 503,
"1": 602
},
"size": [
778.7022705078125,
310
464.4980515341475,
169.87479027400514
],
"flags": {},
"order": 6,
"order": 4,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 59
},
{
"name": "audio",
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "CogVideoX5B",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"pingpong": false,
"save_output": false,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "CogVideoX5B_00009.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 8
},
"muted": false
}
}
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -26,
"1": 400
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"name": "clip",
"type": "CLIP",
"links": [
54,
56
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 1,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 642,
"1": 90
},
"size": {
"0": 315,
"1": 194
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "pab_config",
"type": "PAB_CONFIG",
"link": null
},
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null
},
{
"name": "lora",
"type": "COGLORA",
"link": null
"link": 65
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"name": "conditioning",
"type": "CONDITIONING",
"links": [
60
68
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": null
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"THUDM/CogVideoX-5b",
"bf16",
"disabled",
"disabled",
false
"",
1,
true
]
},
{
"id": 11,
"type": "CogVideoDecode",
"pos": {
"0": 1051,
"1": 748
"0": 1416,
"1": 40
},
"size": {
"0": 300.396484375,
"1": 198
},
"flags": {},
"order": 5,
"order": 6,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 63
"name": "vae",
"type": "VAE",
"link": 71
},
{
"name": "samples",
"type": "LATENT",
"link": 64
"link": 69
}
],
"outputs": [
@ -293,83 +147,297 @@
]
},
{
"id": 34,
"type": "CogVideoSampler",
"id": 36,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 1041,
"1": 342
"0": 645,
"1": 17
},
"size": {
"0": 315.8404846191406,
"1": 358
"0": 315,
"1": 218
},
"flags": {},
"order": 4,
"order": 0,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 60
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null,
"shape": 7
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 61
"name": "lora",
"type": "COGLORA",
"link": null,
"shape": 7
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 62
},
{
"name": "samples",
"type": "LATENT",
"link": null
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": null
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null
"name": "compile_args",
"type": "COMPILEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"name": "model",
"type": "COGVIDEOMODEL",
"links": [
63
70
]
},
{
"name": "vae",
"type": "VAE",
"links": [
71
],
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"THUDM/CogVideoX-5b",
"bf16",
"disabled",
false,
"sdpa",
"main_device"
]
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": 5,
"1": 308
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
54
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CLIPLoader"
},
"widgets_values": [
"t5\\google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
"sd3"
]
},
{
"id": 37,
"type": "EmptyLatentImage",
"pos": {
"0": 643,
"1": 827
},
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
72
]
}
],
"properties": {
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
720,
480,
1
]
},
{
"id": 35,
"type": "CogVideoSampler",
"pos": {
"0": 1042,
"1": 291
},
"size": [
330,
574
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "COGVIDEOMODEL",
"link": 70
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 67
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 68
},
{
"name": "samples",
"type": "LATENT",
"link": 72,
"shape": 7
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": null,
"shape": 7
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null,
"shape": 7
},
{
"name": "controlnet",
"type": "COGVIDECONTROLNET",
"link": null,
"shape": 7
},
{
"name": "tora_trajectory",
"type": "TORAFEATURES",
"link": null,
"shape": 7
},
{
"name": "fastercache",
"type": "FASTERCACHEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"links": [
64
],
"shape": 3
69
]
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
480,
720,
49,
50,
6,
806286757407563,
0,
"fixed",
"DPM++",
"CogVideoXDDIM",
1
]
},
{
"id": 33,
"type": "VHS_VideoCombine",
"pos": {
"0": 1767,
"1": 39
},
"size": [
778.7022705078125,
829.801513671875
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 59
},
{
"name": "audio",
"type": "AUDIO",
"link": null,
"shape": 7
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null,
"shape": 7
},
{
"name": "vae",
"type": "VAE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "CogVideoX5B-T2V",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"pingpong": false,
"save_output": false,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "CogVideoX5B_00001.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 8
},
"muted": false
}
}
}
],
"links": [
@ -381,14 +449,6 @@
0,
"CLIP"
],
[
56,
20,
0,
31,
0,
"CLIP"
],
[
59,
11,
@ -398,43 +458,59 @@
"IMAGE"
],
[
60,
65,
30,
1,
31,
0,
34,
0,
"COGVIDEOPIPE"
"CLIP"
],
[
61,
67,
30,
0,
34,
35,
1,
"CONDITIONING"
],
[
62,
68,
31,
0,
34,
35,
2,
"CONDITIONING"
],
[
63,
34,
69,
35,
0,
11,
0,
"COGVIDEOPIPE"
1,
"LATENT"
],
[
64,
34,
70,
36,
0,
35,
0,
"COGVIDEOMODEL"
],
[
71,
36,
1,
11,
1,
0,
"VAE"
],
[
72,
37,
0,
35,
3,
"LATENT"
]
],
@ -442,10 +518,10 @@
"config": {},
"extra": {
"ds": {
"scale": 0.6934334949442514,
"scale": 0.7627768444387061,
"offset": [
-24.154349208343916,
155.20539218330134
734.1791945221892,
237.29437844909364
]
}
},

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +1,247 @@
{
"last_node_id": 51,
"last_link_id": 114,
"last_node_id": 64,
"last_link_id": 149,
"nodes": [
{
"id": 63,
"type": "CogVideoSampler",
"pos": {
"0": 1142,
"1": 74
},
"size": {
"0": 330,
"1": 574
},
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "COGVIDEOMODEL",
"link": 144
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 145
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 146
},
{
"name": "samples",
"type": "LATENT",
"link": null,
"shape": 7
},
{
"name": "image_cond_latents",
"type": "LATENT",
"link": 147,
"shape": 7
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null,
"shape": 7
},
{
"name": "controlnet",
"type": "COGVIDECONTROLNET",
"link": null,
"shape": 7
},
{
"name": "tora_trajectory",
"type": "TORAFEATURES",
"link": null,
"shape": 7
},
{
"name": "fastercache",
"type": "FASTERCACHEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"links": [
148
]
}
],
"properties": {
"Node name for S&R": "CogVideoSampler"
},
"widgets_values": [
49,
25,
6,
0,
"fixed",
"CogVideoXDDIM",
1
]
},
{
"id": 62,
"type": "CogVideoImageEncode",
"pos": {
"0": 1149,
"1": 711
},
"size": {
"0": 315,
"1": 122
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "vae",
"type": "VAE",
"link": 141
},
{
"name": "start_image",
"type": "IMAGE",
"link": 142
},
{
"name": "end_image",
"type": "IMAGE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"links": [
147
]
}
],
"properties": {
"Node name for S&R": "CogVideoImageEncode"
},
"widgets_values": [
false,
0
]
},
{
"id": 30,
"type": "CogVideoTextEncode",
"pos": {
"0": 493,
"1": 303
},
"size": {
"0": 471.90142822265625,
"1": 168.08047485351562
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 54
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"links": [
145
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": [
149
],
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"a majestic stag is grazing in an enhanced forest, basking in the setting sun filtered by the trees",
1,
false
]
},
{
"id": 36,
"type": "LoadImage",
"pos": {
"0": 335,
"1": 731
},
"size": {
"0": 402.06353759765625,
"1": 396.6225891113281
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
71
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"sd3stag.png",
"image"
]
},
{
"id": 20,
"type": "CLIPLoader",
"pos": {
"0": -26,
"1": 400
"0": -2,
"1": 304
},
"size": {
"0": 451.30548095703125,
"1": 82
},
"flags": {},
"order": 0,
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
@ -37,48 +264,49 @@
]
},
{
"id": 31,
"type": "CogVideoTextEncode",
"id": 60,
"type": "CogVideoDecode",
"pos": {
"0": 497,
"1": 520
"0": 1523,
"1": -6
},
"size": {
"0": 463.01251220703125,
"1": 144
"0": 315,
"1": 198
},
"flags": {},
"order": 5,
"order": 9,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 108
"name": "vae",
"type": "VAE",
"link": 132
},
{
"name": "samples",
"type": "LATENT",
"link": 148
}
],
"outputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"name": "images",
"type": "IMAGE",
"links": [
111
],
"slot_index": 0,
"shape": 3
},
{
"name": "clip",
"type": "CLIP",
"links": null
134
]
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
"Node name for S&R": "CogVideoDecode"
},
"widgets_values": [
"The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ",
1,
true,
240,
360,
0.2,
0.2,
true
]
},
@ -86,21 +314,21 @@
"id": 44,
"type": "VHS_VideoCombine",
"pos": {
"0": 1842,
"1": 345
"0": 1884,
"1": -6
},
"size": [
855.81494140625,
881.2099609375
605.3909912109375,
654.5737362132353
],
"flags": {},
"order": 8,
"order": 10,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 97
"link": 134
},
{
"name": "audio",
@ -133,9 +361,9 @@
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "CogVideoX_Fun",
"filename_prefix": "CogVideoX-I2V",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
@ -146,62 +374,22 @@
"hidden": false,
"paused": false,
"params": {
"filename": "CogVideoX_Fun_00003.mp4",
"filename": "CogVideoX-I2V_00004.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 16
"frame_rate": 8
},
"muted": false
}
}
},
{
"id": 36,
"type": "LoadImage",
"pos": {
"0": 227,
"1": 700
},
"size": {
"0": 391.3421325683594,
"1": 456.8497009277344
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
71
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"sd3stag.png",
"image"
]
},
{
"id": 37,
"type": "ImageResizeKJ",
"pos": {
"0": 688,
"1": 708
"0": 784,
"1": 731
},
"size": {
"0": 315,
@ -244,7 +432,7 @@
"name": "IMAGE",
"type": "IMAGE",
"links": [
112
142
],
"slot_index": 0,
"shape": 3
@ -266,10 +454,10 @@
"Node name for S&R": "ImageResizeKJ"
},
"widgets_values": [
720,
480,
1360,
768,
"lanczos",
true,
false,
16,
0,
0,
@ -277,73 +465,24 @@
]
},
{
"id": 11,
"type": "CogVideoDecode",
"id": 31,
"type": "CogVideoTextEncode",
"pos": {
"0": 1477,
"1": 344
"0": 497,
"1": 520
},
"size": {
"0": 300.396484375,
"1": 198
"0": 463.01251220703125,
"1": 144
},
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 113
},
{
"name": "samples",
"type": "LATENT",
"link": 114
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
97
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "CogVideoDecode"
},
"widgets_values": [
true,
240,
360,
0.2,
0.2,
true
]
},
{
"id": 30,
"type": "CogVideoTextEncode",
"pos": {
"0": 493,
"1": 303
},
"size": {
"0": 471.90142822265625,
"1": 168.08047485351562
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 54
"link": 149
}
],
"outputs": [
@ -351,7 +490,7 @@
"name": "conditioning",
"type": "CONDITIONING",
"links": [
110
146
],
"slot_index": 0,
"shape": 3
@ -359,169 +498,128 @@
{
"name": "clip",
"type": "CLIP",
"links": [
108
],
"slot_index": 1
"links": null
}
],
"properties": {
"Node name for S&R": "CogVideoTextEncode"
},
"widgets_values": [
"majestic stag grazing in a forest and basking in the setting sun",
"",
1,
false
true
]
},
{
"id": 51,
"type": "CogVideoXFunSampler",
"id": 59,
"type": "DownloadAndLoadCogVideoModel",
"pos": {
"0": 1058,
"1": 345
"0": 622,
"1": -25
},
"size": {
"0": 367.79998779296875,
"1": 434
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "pipeline",
"type": "COGVIDEOPIPE",
"link": 109
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 110
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 111
},
{
"name": "start_img",
"type": "IMAGE",
"link": 112,
"shape": 7
},
{
"name": "end_img",
"type": "IMAGE",
"link": null,
"shape": 7
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null,
"shape": 7
},
{
"name": "tora_trajectory",
"type": "TORAFEATURES",
"link": null,
"shape": 7
},
{
"name": "fastercache",
"type": "FASTERCACHEARGS",
"link": null,
"shape": 7
},
{
"name": "vid2vid_images",
"type": "IMAGE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"links": [
113
]
},
{
"name": "samples",
"type": "LATENT",
"links": [
114
]
}
],
"properties": {
"Node name for S&R": "CogVideoXFunSampler"
},
"widgets_values": [
49,
720,
480,
43,
"randomize",
50,
6,
"DDIM",
0.0563,
1
]
},
{
"id": 48,
"type": "DownloadAndLoadCogVideoGGUFModel",
"pos": {
"0": 585,
"1": 34
},
"size": {
"0": 378,
"1": 198
"0": 315,
"1": 218
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "pab_config",
"type": "PAB_CONFIG",
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"link": null,
"shape": 7
},
{
"name": "block_edit",
"type": "TRANSFORMERBLOCKS",
"name": "lora",
"type": "COGLORA",
"link": null,
"shape": 7
},
{
"name": "compile_args",
"type": "COMPILEARGS",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "cogvideo_pipe",
"type": "COGVIDEOPIPE",
"name": "model",
"type": "COGVIDEOMODEL",
"links": [
109
144
]
},
{
"name": "vae",
"type": "VAE",
"links": [
132,
141
],
"slot_index": 0,
"shape": 3
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "DownloadAndLoadCogVideoGGUFModel"
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
"kijai/CogVideoX-5b-1.5-I2V",
"bf16",
"disabled",
false,
"offload_device",
"sdpa",
"main_device"
]
},
{
"id": 64,
"type": "CogVideoImageEncodeFunInP",
"pos": {
"0": 1861.032958984375,
"1": 752.6453247070312
},
"size": {
"0": 380.4000244140625,
"1": 146
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "vae",
"type": "VAE",
"link": null
},
{
"name": "start_image",
"type": "IMAGE",
"link": null
},
{
"name": "end_image",
"type": "IMAGE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "image_cond_latents",
"type": "LATENT",
"links": null
}
],
"properties": {
"Node name for S&R": "CogVideoImageEncodeFunInP"
},
"widgets_values": [
49,
false,
"disabled"
0
]
}
],
@ -543,78 +641,94 @@
"IMAGE"
],
[
97,
11,
132,
59,
1,
60,
0,
"VAE"
],
[
134,
60,
0,
44,
0,
"IMAGE"
],
[
108,
141,
59,
1,
62,
0,
"VAE"
],
[
142,
37,
0,
62,
1,
"IMAGE"
],
[
144,
59,
0,
63,
0,
"COGVIDEOMODEL"
],
[
145,
30,
0,
63,
1,
"CONDITIONING"
],
[
146,
31,
0,
63,
2,
"CONDITIONING"
],
[
147,
62,
0,
63,
4,
"LATENT"
],
[
148,
63,
0,
60,
1,
"LATENT"
],
[
149,
30,
1,
31,
0,
"CLIP"
],
[
109,
48,
0,
51,
0,
"COGVIDEOPIPE"
],
[
110,
30,
0,
51,
1,
"CONDITIONING"
],
[
111,
31,
0,
51,
2,
"CONDITIONING"
],
[
112,
37,
0,
51,
3,
"IMAGE"
],
[
113,
51,
0,
11,
0,
"COGVIDEOPIPE"
],
[
114,
51,
1,
11,
1,
"LATENT"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.7513148009015784,
"scale": 0.8390545288825803,
"offset": [
724.7448506313632,
128.336592104936
351.5513339440394,
161.02862760095286
]
}
},

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -36,10 +36,12 @@ def fp8_linear_forward(cls, original_dtype, input):
else:
return cls.original_forward(input)
def convert_fp8_linear(module, original_dtype):
def convert_fp8_linear(module, original_dtype, params_to_keep={}):
setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules():
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
if not any(keyword in name for keyword in params_to_keep):
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -1,9 +1,41 @@
import os
import torch
import torch.nn as nn
import json
import folder_paths
import comfy.model_management as mm
from typing import Union
def patched_write_atomic(
path_: str,
content: Union[str, bytes],
make_dirs: bool = False,
encode_utf_8: bool = False,
) -> None:
# Write into temporary file first to avoid conflicts between threads
# Avoid using a named temporary file, as those have restricted permissions
from pathlib import Path
import os
import shutil
import threading
assert isinstance(
content, (str, bytes)
), "Only strings and byte arrays can be saved in the cache"
path = Path(path_)
if make_dirs:
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
write_mode = "w" if isinstance(content, str) else "wb"
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
f.write(content)
shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
os.remove(tmp_path)
try:
import torch._inductor.codecache
torch._inductor.codecache.write_atomic = patched_write_atomic
except:
pass
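# The patch above writes through a temp file and then copies over the destination
# (shutil.copy2 + os.remove), so existing Inductor cache files can be overwritten instead
# of raising; it is applied at import time and skipped silently when torch._inductor is
# unavailable.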
import torch
import torch.nn as nn
from .utils import check_diffusers_version, remove_specific_blocks, log
check_diffusers_version()
@ -22,6 +54,10 @@ from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Con
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from .utils import remove_specific_blocks, log
from comfy.utils import load_torch_file
script_directory = os.path.dirname(os.path.abspath(__file__))
@ -61,7 +97,8 @@ class CogVideoLoraSelect:
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
@classmethod
def INPUT_TYPES(s):
@ -72,6 +109,8 @@ class DownloadAndLoadCogVideoModel:
"THUDM/CogVideoX-2b",
"THUDM/CogVideoX-5b",
"THUDM/CogVideoX-5b-I2V",
"kijai/CogVideoX-5b-1.5-T2V",
"kijai/CogVideoX-5b-1.5-I2V",
"bertjiazheng/KoolCogVideoX-5b",
"kijai/CogVideoX-Fun-2b",
"kijai/CogVideoX-Fun-5b",
@ -90,28 +129,33 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fastmode', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}
}
RETURN_TYPES = ("COGVIDEOPIPE",)
RETURN_NAMES = ("cogvideo_pipe", )
RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
RETURN_NAMES = ("model", "vae", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None):
def loadmodel(self, model, precision, quantization="disabled", compile="disabled",
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
check_diffusers_version()
if precision == "fp16" and "1.5" in model:
raise ValueError("1.5 models do not currently work in fp16")
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
manual_offloading = True
transformer_load_device = device if load_device == "main_device" else offload_device
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
@ -134,6 +178,8 @@ class DownloadAndLoadCogVideoModel:
if not os.path.exists(base_path):
base_path = os.path.join(download_path, (model.split("/")[-1]))
download_path = base_path
subfolder = "transformer"
allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]
elif "2b" in model:
if 'img2vid' in model:
@ -144,41 +190,44 @@ class DownloadAndLoadCogVideoModel:
base_path = os.path.join(download_path, "CogVideo2B")
download_path = base_path
repo_id = model
subfolder = "transformer"
allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]
elif "1.5-T2V" in model or "1.5-I2V" in model:
base_path = os.path.join(download_path, "CogVideoX-5b-1.5")
download_path = base_path
subfolder = "transformer_T2V" if "1.5-T2V" in model else "transformer_I2V"
allow_patterns = [f"*{subfolder}*", "*vae*", "*scheduler*"]
repo_id = "kijai/CogVideoX-5b-1.5"
else:
base_path = os.path.join(download_path, (model.split("/")[-1]))
download_path = base_path
repo_id = model
subfolder = "transformer"
allow_patterns = ["*transformer*", "*scheduler*", "*vae*"]
if "2b" in model:
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
else:
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, "transformer")):
if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, subfolder)):
log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=repo_id,
allow_patterns=allow_patterns,
ignore_patterns=["*text_encoder*", "*tokenizer*"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
# transformer
if "Fun" in model:
if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder="transformer")
else:
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder="transformer")
else:
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder="transformer")
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer")
transformer = transformer.to(dtype).to(offload_device)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:
transformer.config.sample_height = 300
transformer.config.sample_width = 300
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@ -190,26 +239,21 @@ class DownloadAndLoadCogVideoModel:
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
# VAE
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
if "cogvideox-2b-img2vid" in model:
pipe.input_with_padding = False
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
#pipeline
pipe = CogVideoXPipeline(
transformer,
scheduler,
dtype=dtype,
is_fun_inpaint=True if "fun" in model.lower() and "pose" not in model.lower() else False
)
if "cogvideox-2b-img2vid" in model:
pipe.input_with_padding = False
#LoRAs
if lora is not None:
from .lora_utils import merge_lora#, load_lora_into_transformer
if "fun" in model.lower():
for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else:
try:
adapter_list = []
adapter_weights = []
for l in lora:
@ -246,46 +290,146 @@ class DownloadAndLoadCogVideoModel:
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
lora_scale = 1
dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
if any(item in lora[-1]["path"].lower() for item in dimension_loras):
lora_scale = lora_scale / lora_rank
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
except: #Fun trainer LoRAs are loaded differently
from .lora_utils import merge_lora
for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
if "fused" in attention_mode:
from diffusers.models.attention import Attention
transformer.fuse_qkv_projections = True
for module in transformer.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
transformer.attention_mode = attention_mode
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
#fp8
if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
if "1.5" in model:
params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
for name, param in pipe.transformer.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if quantization == "fp8_e4m3fn_fastmode":
from .fp8_optimization import convert_fp8_linear
if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
# compilation
if compile == "torch":
pipe.transformer.to(memory_format=torch.channels_last)
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
else:
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
elif compile == "onediff":
from onediffx import compile_pipe
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
pipe = compile_pipe(
pipe,
backend="nexfort",
options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
ignores=["vae"],
fuse_qkv_projections=True if pab_config is None else False,
)
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if "torchao" in quantization:
try:
from torchao.quantization import (
quantize_,
fpx_weight_only,
float8_dynamic_activation_float8_weight,
int8_dynamic_activation_int8_weight
)
except:
raise ImportError("torchao is not installed, please install torchao to use fp8dq")
def filter_fn(module: nn.Module, fqn: str) -> bool:
target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
if any(sub in fqn for sub in target_submodules):
return isinstance(module, nn.Linear)
return False
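# filter_fn is invoked by quantize_ for every submodule together with its fully qualified
# name, so only nn.Linear layers living under attn1/ff are quantized and the norm layers
# keep full precision.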
if "fp6" in quantization: #slower for some reason on 4090
quant_func = fpx_weight_only(3, 2)
elif "fp8dq" in quantization: #very fast on 4090 when compiled
quant_func = float8_dynamic_activation_float8_weight()
elif 'fp8dqrow' in quantization:
from torchao.quantization.quant_api import PerRow
quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
elif 'int8dq' in quantization:
quant_func = int8_dynamic_activation_int8_weight()
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
quantize_(block, quant_func, filter_fn=filter_fn)
manual_offloading = False # to disable manual .to(device) calls
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
manual_offloading = False
# CogVideoXBlock(
# (norm1): CogVideoXLayerNormZero(
# (silu): SiLU()
# (linear): Linear(in_features=512, out_features=18432, bias=True)
# (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
# )
# (attn1): Attention(
# (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
# (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
# (to_q): Linear(in_features=3072, out_features=3072, bias=True)
# (to_k): Linear(in_features=3072, out_features=3072, bias=True)
# (to_v): Linear(in_features=3072, out_features=3072, bias=True)
# (to_out): ModuleList(
# (0): Linear(in_features=3072, out_features=3072, bias=True)
# (1): Dropout(p=0.0, inplace=False)
# )
# )
# (norm2): CogVideoXLayerNormZero(
# (silu): SiLU()
# (linear): Linear(in_features=512, out_features=18432, bias=True)
# (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
# )
# (ff): FeedForward(
# (net): ModuleList(
# (0): GELU(
# (proj): Linear(in_features=3072, out_features=12288, bias=True)
# )
# (1): Dropout(p=0.0, inplace=False)
# (2): Linear(in_features=12288, out_features=3072, bias=True)
# (3): Dropout(p=0.0, inplace=False)
# )
# )
# )
# if compile == "onediff":
# from onediffx import compile_pipe
# os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
# pipe = compile_pipe(
# pipe,
# backend="nexfort",
# options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
# ignores=["vae"],
# fuse_qkv_projections= False,
# )
pipeline = {
"pipe": pipe,
"dtype": dtype,
"base_path": base_path,
"onediff": True if compile == "onediff" else False,
"cpu_offloading": enable_sequential_cpu_offload,
"manual_offloading": manual_offloading,
"scheduler_config": scheduler_config,
"model_name": model
"model_name": model,
}
return (pipeline,)
return (pipeline, vae)
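# The loader now returns this plain dict as the COGVIDEOMODEL output (keys: pipe, dtype,
# base_path, onediff, cpu_offloading, manual_offloading, scheduler_config, model_name)
# plus the VAE as a separate second output, replacing the old single COGVIDEOPIPE object.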
#region GGUF
class DownloadAndLoadCogVideoGGUFModel:
@classmethod
def INPUT_TYPES(s):
@ -295,12 +439,12 @@ class DownloadAndLoadCogVideoGGUFModel:
[
"CogVideoX_5b_GGUF_Q4_0.safetensors",
"CogVideoX_5b_I2V_GGUF_Q4_0.safetensors",
"CogVideoX_5b_1_5_I2V_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
"CogVideoX_5b_Interpolation_GGUF_Q4_0.safetensors",
"CogVideoX_5b_Tora_GGUF_Q4_0.safetensors",
],
),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
@ -309,20 +453,20 @@ class DownloadAndLoadCogVideoGGUFModel:
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
},
"optional": {
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
#"lora": ("COGLORA", {"default": None}),
"compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
}
}
RETURN_TYPES = ("COGVIDEOPIPE",)
RETURN_NAMES = ("cogvideo_pipe", )
RETURN_TYPES = ("COGVIDEOMODEL", "VAE",)
RETURN_NAMES = ("model", "vae",)
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None, compile="disabled"):
check_diffusers_version()
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
block_edit=None, compile="disabled", attention_mode="sdpa"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -358,7 +502,6 @@ class DownloadAndLoadCogVideoGGUFModel:
with open(transformer_path) as f:
transformer_config = json.load(f)
sd = load_torch_file(gguf_path)
from . import mz_gguf_loader
import importlib
@ -370,45 +513,45 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
elif "I2V" in model or "Interpolation" in model:
transformer_config["in_channels"] = 32
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
if "1_5" in model:
transformer_config["ofs_embed_dim"] = 512
transformer_config["use_learned_positional_embeddings"] = False
transformer_config["patch_size_t"] = 2
transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 300
transformer_config["sample_width"] = 300
else:
transformer_config["in_channels"] = 16
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
else:
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
for name, param in transformer.named_parameters():
if name != "pos_embedding":
param.data = param.data.to(torch.float8_e4m3fn)
else:
param.data = param.data.to(torch.float16)
else:
transformer.to(torch.float8_e4m3fn)
cast_dtype = torch.float16
elif "1_5" in model:
params_to_keep = {"norm1.linear.weight", "patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out"}
cast_dtype = torch.bfloat16
for name, param in transformer.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
else:
param.data = param.data.to(cast_dtype)
#for name, param in transformer.named_parameters():
# print(name, param.data.dtype)
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
if load_device == "offload_device":
transformer.to(offload_device)
else:
transformer.to(device)
transformer.attention_mode = attention_mode
if fp8_fastmode:
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
if "1.5" in model:
params_to_keep.update({"ff","norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, vae_dtype)
convert_fp8_linear(transformer, vae_dtype, params_to_keep=params_to_keep)
if compile == "torch":
# compilation
@ -435,22 +578,25 @@ class DownloadAndLoadCogVideoGGUFModel:
with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
vae_config = json.load(f)
#VAE
vae_sd = load_torch_file(vae_path)
if "fun" in model:
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
else:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
del vae_sd
pipe = CogVideoXPipeline(transformer, scheduler, dtype=vae_dtype)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
sd = load_torch_file(gguf_path)
pipe.transformer = mz_gguf_loader.quantize_load_state_dict(pipe.transformer, sd, device="cpu")
del sd
if load_device == "offload_device":
pipe.transformer.to(offload_device)
else:
pipe.transformer.to(device)
pipeline = {
"pipe": pipe,
"dtype": vae_dtype,
@ -458,11 +604,289 @@ class DownloadAndLoadCogVideoGGUFModel:
"onediff": False,
"cpu_offloading": enable_sequential_cpu_offload,
"scheduler_config": scheduler_config,
"model_name": model
"model_name": model,
"manual_offloading": True,
}
return (pipeline, vae)
#region ModelLoader
class CogVideoXModelLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}),
"base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
"quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "optional quantization method"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
},
"optional": {
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
}
}
RETURN_TYPES = ("COGVIDEOMODEL",)
RETURN_NAMES = ("model", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
manual_offloading = True
transformer_load_device = device if load_device == "main_device" else offload_device
mm.soft_empty_cache()
base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision]
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
sd = load_torch_file(model_path, device=transformer_load_device)
model_type = ""
if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2):
model_type = "fun_5b"
elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2):
model_type = "5b"
elif sd["patch_embed.proj.weight"].shape == (3072, 128):
model_type = "5b_1_5"
elif sd["patch_embed.proj.weight"].shape == (3072, 256):
model_type = "5b_I2V_1_5"
elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
model_type = "fun_2b"
elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
model_type = "2b"
elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
if "pos_embedding" in sd:
model_type = "fun_5b_pose"
else:
model_type = "I2V_5b"
else:
raise Exception("Selected model is not recognized")
log.info(f"Detected CogVideoX model type: {model_type}")
if "5b" in model_type:
scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_5b.json')
elif "2b" in model_type:
scheduler_config_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
transformer_config_path = os.path.join(script_directory, 'configs', 'transformer_config_2b.json')
with open(transformer_config_path) as f:
transformer_config = json.load(f)
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
transformer_config["in_channels"] = 32
if "1_5" in model_type:
transformer_config["ofs_embed_dim"] = 512
elif "fun" in model_type:
transformer_config["in_channels"] = 33
else:
transformer_config["in_channels"] = 16
if "1_5" in model_type:
transformer_config["use_learned_positional_embeddings"] = False
transformer_config["patch_size_t"] = 2
transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 300
transformer_config["sample_width"] = 300
with init_empty_weights():
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
#load weights
#params_to_keep = {}
log.info("Using accelerate to load and assign model weights to device...")
for name, param in transformer.named_parameters():
#dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
del sd
#scheduler
with open(scheduler_config_path) as f:
scheduler_config = json.load(f)
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config, subfolder="scheduler")
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
if "fused" in attention_mode:
from diffusers.models.attention import Attention
transformer.fuse_qkv_projections = True
for module in transformer.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
transformer.attention_mode = attention_mode
if "fun" in model_type:
if not "pose" in model_type:
raise NotImplementedError("Fun models besides pose are not supported with this loader yet")
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype)
else:
pipe = CogVideoXPipeline(transformer, scheduler, dtype=base_dtype)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
#LoRAs
if lora is not None:
from .lora_utils import merge_lora#, load_lora_into_transformer
if "fun" in model.lower():
for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else:
adapter_list = []
adapter_weights = []
for l in lora:
fuse = True if l["fuse_lora"] else False
lora_sd = load_torch_file(l["path"])
for key, val in lora_sd.items():
if "lora_B" in key:
lora_rank = val.shape[1]
break
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
adapter_name = l['path'].split("/")[-1].split(".")[0]
adapter_weight = l['strength']
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
#transformer = load_lora_into_transformer(lora, transformer)
adapter_list.append(adapter_name)
adapter_weights.append(adapter_weight)
for l in lora:
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
if fuse:
lora_scale = 1
dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
if any(item in lora[-1]["path"].lower() for item in dimension_loras):
lora_scale = lora_scale / lora_rank
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
#quantization
if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
if "1.5" in model:
params_to_keep.update({"norm1.linear.weight", "ofs_embedding", "norm_final", "norm_out", "proj_out"})
for name, param in pipe.transformer.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if quantization == "fp8_e4m3fn_fast":
from .fp8_optimization import convert_fp8_linear
if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(pipe.transformer, base_dtype, params_to_keep=params_to_keep)
#compile
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if "torchao" in quantization:
try:
from torchao.quantization import (
quantize_,
fpx_weight_only,
float8_dynamic_activation_float8_weight,
int8_dynamic_activation_int8_weight
)
except:
raise ImportError("torchao is not installed, please install torchao to use fp8dq")
def filter_fn(module: nn.Module, fqn: str) -> bool:
target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
if any(sub in fqn for sub in target_submodules):
return isinstance(module, nn.Linear)
return False
if "fp6" in quantization: #slower for some reason on 4090
quant_func = fpx_weight_only(3, 2)
elif "fp8dq" in quantization: #very fast on 4090 when compiled
quant_func = float8_dynamic_activation_float8_weight()
elif 'fp8dqrow' in quantization:
from torchao.quantization.quant_api import PerRow
quant_func = float8_dynamic_activation_float8_weight(granularity=PerRow())
elif 'int8dq' in quantization:
quant_func = int8_dynamic_activation_int8_weight()
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
quantize_(block, quant_func, filter_fn=filter_fn)
manual_offloading = False # to disable manual .to(device) calls
log.info(f"Quantized transformer blocks to {quantization}")
# if load_device == "offload_device":
# pipe.transformer.to(offload_device)
# else:
# pipe.transformer.to(device)
pipeline = {
"pipe": pipe,
"dtype": base_dtype,
"base_path": model,
"onediff": False,
"cpu_offloading": enable_sequential_cpu_offload,
"scheduler_config": scheduler_config,
"model_name": model,
"manual_offloading": manual_offloading,
}
return (pipeline,)
#region VAE
class CogVideoXVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"}
),
}
}
RETURN_TYPES = ("VAE",)
RETURN_NAMES = ("vae", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
def loadmodel(self, model_name, precision):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
vae_config = json.load(f)
model_path = folder_paths.get_full_path("vae", model_name)
vae_sd = load_torch_file(model_path)
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
vae.load_state_dict(vae_sd)
return (vae,)
#region Tora
class DownloadAndLoadToraModel:
@classmethod
def INPUT_TYPES(s):
@ -483,9 +907,6 @@ class DownloadAndLoadToraModel:
DESCRIPTION = "Downloads and loads the the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'"
def loadmodel(self, model):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
@ -570,7 +991,7 @@ class DownloadAndLoadToraModel:
}
return (toramodel,)
#region controlnet
class DownloadAndLoadCogVideoControlNet:
@classmethod
def INPUT_TYPES(s):
@ -625,6 +1046,8 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoXVAELoader": CogVideoXVAELoader,
"CogVideoXModelLoader": CogVideoXModelLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -632,4 +1055,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
"DownloadAndLoadToraModel": "(Down)load Tora Model",
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
"CogVideoXModelLoader": "CogVideoX Model Loader",
}

View File

@ -19,17 +19,21 @@ class quantize_lazy_load():
def quantize_load_state_dict(model, state_dict, device="cpu"):
Q4_0_qkey = []
quant_keys = []
for key in state_dict.keys():
if key.endswith(".Q4_0_qweight"):
Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
quant_keys.append(key.replace(".Q4_0_qweight", ""))
qtype = "Q4_0"
elif key.endswith(".Q8_0_qweight"):
quant_keys.append(key.replace(".Q8_0_qweight", ""))
qtype = "Q8_0"
for name, module in model.named_modules():
if name in Q4_0_qkey:
if name in quant_keys:
q_linear = WQLinear_GGUF.from_linear(
linear=module,
device=device,
qtype="Q4_0",
qtype=qtype,
)
set_op_by_name(model, name, q_linear)
@ -117,14 +121,14 @@ class WQLinear_GGUF(nn.Module):
@torch.no_grad()
def forward(self, x):
# x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
if self.qtype == "Q4_0":
x = F.linear(x, dequantize_blocks_Q4_0(
self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
elif self.qtype == "Q8_0":
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
else:
raise ValueError(f"Unknown qtype: {self.qtype}")
return x
return F.linear(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
def split_block_dims(blocks, *args):
@ -153,6 +157,7 @@ def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
GGML_QUANT_SIZES = {
"Q4_0": (32, 2 + 16),
"Q8_0": (32, 2 + 32),
}
@ -186,3 +191,31 @@ def dequantize_blocks_Q4_0(data, dtype=torch.float16):
)).to(dtype)
return out
def dequantize_blocks_Q8_0(data, dtype=torch.float16):
block_size, type_size = GGML_QUANT_SIZES["Q8_0"]
data = data.to(torch.uint8)
shape = data.shape
rows = data.reshape(
(-1, data.shape[-1])
).view(torch.uint8)
n_blocks = rows.numel() // type_size
blocks = data.reshape((n_blocks, type_size))
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(torch.float32)
qs = qs.view(torch.int8).to(torch.float32)
out = (d * qs)
out = out.reshape(quant_shape_from_byte_shape(
shape,
qtype="Q8_0",
)).to(dtype)
return out
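# A quick self-contained check (pure PyTorch, independent of the loader above) of the
# Q8_0 layout the new dequantizer expects: each block is a float16 scale d followed by
# 32 int8 quants, and dequantization is simply d * qs reshaped to the weight shape.
import torch

def quantize_q8_0_block(w: torch.Tensor) -> torch.Tensor:
    # 32 float weights -> one Q8_0 block of 2 + 32 bytes (cf. GGML_QUANT_SIZES["Q8_0"])
    assert w.numel() == 32
    d = (w.abs().max() / 127.0).to(torch.float16)
    qs = torch.round(w / d.to(w.dtype)).clamp(-127, 127).to(torch.int8)
    return torch.cat([d.reshape(1).view(torch.uint8), qs.view(torch.uint8)])

block = quantize_q8_0_block(torch.randn(32))
print(block.numel())  # 34 bytes per block, matching (32, 2 + 32)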

934  nodes.py

File diff suppressed because it is too large Load Diff

View File

@ -17,27 +17,23 @@ import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import math
from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
#from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.loaders import CogVideoXLoraLoaderMixin
from .embeddings import get_3d_rotary_pos_embed
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from comfy.utils import ProgressBar
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from .videosys.core.pipeline import VideoSysPipeline
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
from .videosys.core.pab_mgr import set_pab_manager
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
@ -114,7 +110,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@ -122,15 +118,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
@ -142,33 +129,25 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
def __init__(
self,
vae: AutoencoderKLCogVideoX,
transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
original_mask = None,
pab_config = None
dtype: torch.dtype = torch.bfloat16,
is_fun_inpaint: bool = False,
):
super().__init__()
self.register_modules(
vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.original_mask = original_mask
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
if pab_config is not None:
set_pab_manager(pab_config)
self.register_modules(transformer=transformer, scheduler=scheduler)
self.vae_scale_factor_spatial = 8
self.vae_scale_factor_temporal = 4
self.vae_latent_channels = 16
self.vae_dtype = dtype
self.is_fun_inpaint = is_fun_inpaint
self.input_with_padding = True
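# The VAE is no longer registered on the pipeline: the CogVideoX VAE scale factors and
# latent channel count are hard-coded (8 spatial, 4 temporal, 16 channels) and decoding
# is handled outside the pipeline via the separate VAE output and decode node.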
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
self, batch_size, num_channels_latents, num_frames, height, width, device, generator, timesteps, denoise_strength,
num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
):
shape = (
@ -178,14 +157,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae_dtype)
if freenoise:
print("Applying FreeNoise")
logger.info("Applying FreeNoise")
# code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
video_length = num_frames // 4
delta = context_size - context_overlap
@ -225,20 +200,20 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
latent_timestep = timesteps[:1]
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
frames_needed = noise.shape[1]
current_frames = latents.shape[1]
if frames_needed > current_frames:
repeat_factor = frames_needed // current_frames
repeat_factor = frames_needed - current_frames
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
latents = torch.cat((latents, additional_frame), dim=1)
latents = torch.cat((additional_frame, latents), dim=1)
self.additional_frames = repeat_factor
elif frames_needed < current_frames:
latents = latents[:, :frames_needed, :, :, :]
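# Shorter vid2vid inputs are now front-padded with exactly the missing number of noise
# frames (frames_needed - current_frames) rather than a floor-divided repeat count, and
# longer inputs are truncated, before noise is added for the partial-denoise schedule.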
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
latents = self.scheduler.add_noise(latents, noise.to(device), latent_timestep)
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
return latents, timesteps, noise
return latents, timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@ -286,29 +261,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps.to(device), num_inference_steps - t_start
def _gaussian_weights(self, t_tile_length, t_batch_size):
from numpy import pi, exp, sqrt
var = 0.01
midpoint = (t_tile_length - 1) / 2 # -1 because index goes from 0 to latent_width - 1
t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
weights = torch.tensor(t_probs)
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
@ -316,34 +268,41 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
width: int,
num_frames: int,
device: torch.device,
start_frame: int = None,
end_frame: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
if start_frame is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
if p_t is None:
# CogVideoX 1.0 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
else:
# CogVideoX 1.5 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)
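# With patch_size_t = 2 the temporal axis is patched as well, so the RoPE table covers
# ceil(num_frames / p_t) positions (e.g. 13 latent frames -> 7 temporal positions);
# grid_type="slice" builds the table at max_size (the model's base resolution) and
# slices out the active grid.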
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
@ -370,16 +329,15 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
height: int = 480,
width: int = 720,
num_frames: int = 48,
t_tile_length: int = 12,
t_tile_overlap: int = 4,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
denoise_strength: float = 1.0,
num_videos_per_prompt: int = 1,
sigmas: Optional[List[float]] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
fun_mask: Optional[torch.Tensor] = None,
image_cond_latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
@ -419,8 +377,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@ -436,15 +392,13 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
#assert (
# num_frames <= 48 and num_frames % fps == 0 and fps == 8
#), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
self.num_frames = num_frames
# 1. Check inputs. Raise error if not correct
self.check_inputs(
height,
@ -462,33 +416,42 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
do_classifier_free_guidance = guidance_scale[0] > 1.0
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_embeds = prompt_embeds.to(self.vae.dtype)
prompt_embeds = prompt_embeds.to(self.vae_dtype)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
if sigmas is None:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, sigmas=sigmas, device=device)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.vae.config.latent_channels
latent_channels = self.vae_latent_channels
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
if latents is None and num_frames == t_tile_length:
num_frames += 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = getattr(self.transformer.config, "patch_size_t", None)
if patch_size_t is None:
self.transformer.config.patch_size_t = None
ofs_embed_dim = getattr(self.transformer.config, "ofs_embed_dim", None)
if ofs_embed_dim is None:
self.transformer.config.ofs_embed_dim = None
if self.original_mask is not None:
image_latents = latents
original_image_latents = image_latents
self.additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
self.additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += self.additional_frames * self.vae_scale_factor_temporal
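# Example: num_frames=49 -> latent_frames = (49 - 1) // 4 + 1 = 13; 13 % 2 == 1, so one
# extra latent frame is added and num_frames becomes 49 + 4 = 53 (14 latent frames,
# divisible by patch_size_t).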
latents, timesteps, noise = self.prepare_latents(
batch_size * num_videos_per_prompt,
latents, timesteps = self.prepare_latents(
batch_size,
latent_channels,
num_frames,
height,
width,
self.vae.dtype,
device,
generator,
timesteps,
@ -499,95 +462,82 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
context_overlap=context_overlap,
freenoise=freenoise,
)
latents = latents.to(self.vae.dtype)
#print("latents", latents.shape)
latents = latents.to(self.vae_dtype)
if self.is_fun_inpaint and fun_mask is None: # For FUN inpaint vid2vid, we need to mask all the latents
fun_mask = torch.zeros_like(latents[:, :, :1, :, :], device=latents.device, dtype=latents.dtype)
fun_masked_video_latents = torch.zeros_like(latents, device=latents.device, dtype=latents.dtype)
# 5.5.
if image_cond_latents is not None:
if image_cond_latents.shape[1] > 1:
if image_cond_latents.shape[1] == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
batch_size,
(latents.shape[1] - 2),
self.vae.config.latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
batch_size,
(latents.shape[1] - 2),
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
else:
elif image_cond_latents.shape[1] == 1:
logger.info("Only one image conditioning frame received, img2vid")
if self.input_with_padding:
padding_shape = (
batch_size,
(latents.shape[1] - 1),
self.vae.config.latent_channels,
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
# Duplicate the leading frame(s) so the conditioning latent count becomes divisible by patch_size_t
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
else:
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
else:
logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# masks
if self.original_mask is not None:
mask = self.original_mask.to(device)
logger.info(f"self.original_mask: {self.original_mask.shape}")
mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
if mask.shape[0] != latents.shape[1]:
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
else:
mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
logger.info(f"latents: {latents.shape}")
logger.info(f"mask: {mask.shape}")
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps)
# 8. context schedule and temporal tiling
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
logger.info("Temporal tiling enabled")
elif context_schedule is not None:
# 7. context schedule
if context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .cogvideox_fun.context import get_context_scheduler
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
#todo ofs embeds?
else:
use_temporal_tiling = False
use_context_schedule = False
logger.info("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
logger.info("Context schedule disabled")
# 7.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 7.6. Create ofs embeds if required
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
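# ofs ("offset") embeddings are only defined when the transformer was trained with
# ofs_embed_dim (the CogVideoX 1.5 I2V checkpoints); the constant 2.0 fill value follows
# the reference diffusers pipeline, otherwise ofs stays None.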
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
# 9. Controlnet
#8. Controlnet
if controlnet is not None:
self.controlnet = controlnet["control_model"].to(device)
if self.transformer.dtype == torch.float8_e4m3fn:
@ -616,112 +566,26 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if tora is not None:
trajectory_length = tora["video_flow_features"].shape[1]
logger.info(f"Tora trajectory length: {trajectory_length}")
#if trajectory_length != latents.shape[1]:
#    raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
from .latent_preview import prepare_callback
callback = prepare_callback(self.transformer, num_inference_steps)
# 9. Denoising loop
comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar:
old_pred_original_sample = None # for DPM-solver++
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1]
latents_all_list = []
# =====================================================
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
#print("latents",latents.shape)
# start diff diff
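# Differential-diffusion style region preservation: original_image_latents are re-noised to
# the next timestep and blended back in as image_latent * mask + latents * (1 - mask); the
# threshold falls linearly from the first to the last timestep, so regions with mask values
# near 1 track the source from the start while lower-mask regions only blend in later.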
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
)
mask = mask.to(latents)
ts_from = timesteps[0]
ts_to = timesteps[-1]
threshold = (t - ts_to) / (ts_from - ts_to)
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# ==========================================
elif use_context_schedule:
# region context schedule sampling
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
counter = torch.zeros_like(latent_model_input)
@ -819,7 +683,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
noise_pred /= counter
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
@ -839,14 +703,26 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# region sampling
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if image_cond_latents is not None:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
if fun_mask is not None: #for fun img2vid and interpolation
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
else:
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
else: # for Fun inpaint vid2vid
if fun_mask is not None:
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@ -866,9 +742,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
return_dict=False,
)[0]
if isinstance(controlnet_states, (tuple, list)):
controlnet_states = [x.to(dtype=self.vae_dtype) for x in controlnet_states]
else:
controlnet_states = controlnet_states.to(dtype=self.vae_dtype)
# predict noise model_output
@ -877,53 +753,43 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
ofs=ofs_emb,
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_weights,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()
if isinstance(self.scheduler, CogVideoXDPMScheduler):
self._guidance_scale[i] = 1 + guidance_scale[i] * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
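# Dynamic CFG, following the cosine schedule used in the reference CogVideoX pipeline:
# the effective guidance ramps between 1 and 1 + guidance_scale as a function of the
# current timestep; this wrapper keeps a per-step value since guidance_scale is a list here.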
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale[i] * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents.to(self.vae_dtype), **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents.to(self.vae_dtype),
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# start diff diff
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
)
mask = mask.to(latents)
ts_from = timesteps[0]
ts_to = timesteps[-1]
threshold = (t - ts_to) / (ts_from - ts_to)
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
if callback is not None:
callback(i, latents.detach()[-1], None, num_inference_steps)
else:
comfy_pbar.update(1)
# Offload all models


@ -1,9 +1,9 @@
[project]
name = "comfyui-cogvideoxwrapper"
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
version = "1.5.0"
license = {file = "LICENSE"}
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
[project.urls]
Repository = "https://github.com/kijai/ComfyUI-CogVideoXWrapper"


@ -1,5 +1,32 @@
# WORK IN PROGRESS
## BREAKING Update8
This is a big one, and unfortunately the necessary cleanup and refactoring will break every old workflow as it is.
I apologize for the inconvenience: if I don't do this now I'll keep making it worse until maintaining becomes too much of a chore, so from my pov there was no choice.
*Please either use the new workflows or fix the nodes in your old ones before posting issue reports!*
The old version will be kept in a legacy branch, but it will not be maintained.
- Support CogVideoX 1.5 models
- Major code cleanup (it was bad, still isn't great, wip)
- Merge Fun -model functionality into main pipeline:
  - All Fun -specific nodes, besides the image encode node for Fun -InP models, are gone
  - Main CogVideo Sampler works with Fun models
  - DimensionX LoRAs now work with Fun models as well
- Remove width/height from the sampler widgets and detect from input instead; this means text2vid now requires using empty latents
- Separate VAE from the model, allow using fp32 VAE
- Add ability to load some of the non-GGUF models as single files (only a few available for now: https://huggingface.co/Kijai/CogVideoX-comfy)
- Add some torchao quantizations as options
- Add interpolation as an option for the main encode node, the old interpolation-specific node is gone
- torch.compile optimizations
- Remove PAB in favor of FasterCache and cleaner code
- other smaller things I forgot about at this point
For Fun -model based workflows the change is more drastic; for others, migrating generally means re-setting many of the nodes.
## Update7
- Refactored the Fun version's sampler to accept any resolution, this should make it a lot simpler to use with Tora. **BREAKS OLD WORKFLOWS**, old FunSampler nodes need to be remade.


@ -3,3 +3,4 @@ diffusers>=0.31.0
accelerate>=0.33.0
einops
peft
opencv-python


@ -1,5 +1,5 @@
import importlib.metadata
import torch
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
@ -19,4 +19,12 @@ def remove_specific_blocks(model, block_indices_to_remove):
new_blocks = [block for i, block in enumerate(transformer_blocks) if i not in block_indices_to_remove]
model.transformer_blocks = nn.ModuleList(new_blocks)
return model
def print_memory(device):
memory = torch.cuda.memory_allocated(device) / 1024**3
max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
log.info(f"Allocated memory: {memory=:.3f} GB")
log.info(f"Max allocated memory: {max_memory=:.3f} GB")
log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
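# Usage sketch (assumes a CUDA device): call print_memory(device) after sampling, then
# torch.cuda.reset_peak_memory_stats(device) to start a fresh peak-measurement window.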


@ -1,621 +0,0 @@
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from typing import Any, Dict, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from .core.pab_mgr import enable_pab, if_broadcast_spatial
from .modules.embeddings import apply_rotary_emb
#from .modules.embeddings import CogVideoXPatchEmbed
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
try:
from sageattention import sageattn
SAGEATTN_IS_AVAVILABLE = True
except:
SAGEATTN_IS_AVAVILABLE = False
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
attn_heads = attn.heads
inner_dim = key.shape[-1]
head_dim = inner_dim // attn_heads
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAVILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
block_idx: int = 0,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# parallel
#self.attn1.parallel_manager = None
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# pab
self.attn_count = 0
self.last_attn = None
self.block_idx = block_idx
#@torch.compiler.disable()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep=None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if enable_pab():
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
self.fuser_list = None
# parallel
#self.parallel_manager = None
# def enable_parallel(self, dp_size, sp_size, enable_cp):
# # update cfg parallel
# if enable_cp and sp_size % 2 == 0:
# sp_size = sp_size // 2
# cp_size = 2
# else:
# cp_size = 1
# self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
# for _, module in self.named_modules():
# if hasattr(module, "parallel_manager"):
# module.parallel_manager = self.parallel_manager
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
video_flow_features: Optional[torch.Tensor] = None,
):
# if self.parallel_manager.cp_size > 1:
# (
# hidden_states,
# encoder_hidden_states,
# timestep,
# timestep_cond,
# image_rotary_emb,
# ) = batch_func(
# partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
# hidden_states,
# encoder_hidden_states,
# timestep,
# timestep_cond,
# image_rotary_emb,
# )
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# if self.parallel_manager.sp_size > 1:
# set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
# hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
timestep=timesteps if enable_pab() else None,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
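# Each ControlNet feature map is injected as a residual into the matching transformer block's
# hidden states; controlnet_weights can be a scalar applied uniformly or a per-block list/tensor.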
#if self.parallel_manager.sp_size > 1:
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
#if self.parallel_manager.cp_size > 1:
# output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@ -1,232 +0,0 @@
PAB_MANAGER = None
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
spatial_broadcast: bool = False,
spatial_threshold: list = None,
spatial_range: int = None,
temporal_broadcast: bool = False,
temporal_threshold: list = None,
temporal_range: int = None,
mlp_broadcast: bool = False,
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_range = temporal_range
self.mlp_broadcast = mlp_broadcast
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
self.mlp_temporal_outputs = {}
self.mlp_spatial_outputs = {}
class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
print(init_prompt)
def if_broadcast_cross(self, timestep: int, count: int):
if (
self.config.cross_broadcast
and (timestep is not None)
and (count % self.config.cross_range != 0)
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_temporal(self, timestep: int, count: int):
if (
self.config.temporal_broadcast
and (timestep is not None)
and (count % self.config.temporal_range != 0)
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
and (count % self.config.spatial_range != 0)
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
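# Spatial attention is broadcast (the cached output reused) only when the timestep lies inside
# the configured threshold window and the running count is not on a refresh step
# (count % spatial_range == 0); the count wraps modulo the total number of steps.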
@staticmethod
def _is_t_in_skip_config(all_timesteps, timestep, config):
is_t_in_skip_config = False
skip_range = None
for key in config:
if key not in all_timesteps:
continue
index = all_timesteps.index(key)
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
if timestep in skip_range:
is_t_in_skip_config = True
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
break
return is_t_in_skip_config, skip_range
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not self.config.mlp_broadcast:
return False, None, False, None
if is_temporal:
cur_config = self.config.mlp_temporal_broadcast_config
else:
cur_config = self.config.mlp_spatial_broadcast_config
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
next_flag = False
if (
self.config.mlp_broadcast
and (timestep is not None)
and (timestep in cur_config)
and (block_idx in cur_config[timestep]["block"])
):
flag = False
next_flag = True
count = count + 1
elif (
self.config.mlp_broadcast
and (timestep is not None)
and (is_t_in_skip_config)
and (block_idx in cur_config[skip_range[0]]["block"])
):
flag = True
count = 0
else:
flag = False
return flag, count, next_flag, skip_range
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
if is_temporal:
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
else:
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
skip_start_t = skip_range[0]
if is_temporal:
skip_output = (
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_temporal_outputs is not None
else None
)
else:
skip_output = (
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_spatial_outputs is not None
else None
)
if skip_output is not None:
if timestep == skip_range[-1]:
# TODO: save memory
if is_temporal:
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
else:
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
else:
raise ValueError(
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
)
return skip_output
def get_spatial_mlp_outputs(self):
return self.config.mlp_spatial_outputs
def get_temporal_mlp_outputs(self):
return self.config.mlp_temporal_outputs
def set_pab_manager(config: PABConfig):
global PAB_MANAGER
PAB_MANAGER = PABManager(config)
def enable_pab():
if PAB_MANAGER is None:
return False
return (
PAB_MANAGER.config.cross_broadcast
or PAB_MANAGER.config.spatial_broadcast
or PAB_MANAGER.config.temporal_broadcast
)
def update_steps(steps: int):
if PAB_MANAGER is not None:
PAB_MANAGER.config.steps = steps
def if_broadcast_cross(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_cross(timestep, count)
def if_broadcast_temporal(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not enable_pab():
return False, count
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)


@ -1,44 +0,0 @@
import inspect
from abc import abstractmethod
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
class VideoSysPipeline(DiffusionPipeline):
def __init__(self):
super().__init__()
@staticmethod
def set_eval_and_device(device: torch.device, *modules):
for module in modules:
module.eval()
module.to(device)
@abstractmethod
def generate(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
"""
In diffusers, it is a convention to call the pipeline object directly.
In VideoSys, we use the generate method instead for a clearer interface.
This is a wrapper around the generate method to keep the diffusers-style call working.
"""
return self.generate(*args, **kwargs)
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
# modify: remove the config module from the expected modules
expected_modules = expected_modules - {"config"}
optional_names = list(optional_parameters)
for name in optional_names:
if name in cls._optional_components:
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters


@ -1,3 +0,0 @@
import torch.nn as nn
approx_gelu = lambda: nn.GELU(approximate="tanh")


@ -1,71 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXDownsample3D(nn.Module):
# Todo: Wait for paper release.
r"""
A 3D Downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `2`):
Stride of the convolution.
padding (`int`, defaults to `0`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x


@ -1,308 +0,0 @@
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
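# Resulting token layout: [text tokens | frame 0 patches | frame 1 patches | ...]; each frame
# contributes (height // patch_size) * (width // patch_size) tokens in row-major order.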
class OpenSoraPatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs // s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
s_emb = self.mlp(s_freq)
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
@property
def dtype(self):
return next(self.parameters()).dtype
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
`torch.Tensor`: The query or key tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Use for example in Lumina
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Use for example in Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
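# A minimal usage sketch of apply_rotary_emb (hypothetical shapes, not part of the
# original file): rotate query/key tensors of shape [B, H, S, D] with a precomputed
# (cos, sin) pair of shape [S, D], as in the use_real=True path above.
if __name__ == "__main__":
    batch, heads, seq_len, head_dim = 2, 4, 8, 64
    query = torch.randn(batch, heads, seq_len, head_dim)
    key = torch.randn(batch, heads, seq_len, head_dim)
    # Stand-in angles; in the pipeline these come from the 3D rotary grid computed above.
    angles = torch.randn(seq_len, head_dim)
    cos, sin = angles.cos(), angles.sin()
    query = apply_rotary_emb(query, (cos, sin), use_real=True)
    key = apply_rotary_emb(key, (cos, sin), use_real=True)
    assert query.shape == key.shape == (batch, heads, seq_len, head_dim)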

View File

@ -1,85 +0,0 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
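# Shape sketch (hypothetical sizes, not part of the original file): for a batch B with
# embedding_dim C and a conditioning vector temb of shape [B, conditioning_dim], the
# 6-way chunk above yields one shift/scale/gate triple per stream:
#   hidden_states:         [B, S_video, C] -> normalized, then scaled and shifted
#   encoder_hidden_states: [B, S_text,  C] -> same treatment from its own chunk
#   gate, enc_gate:        [B, 1, C]       -> applied to the block outputs by the caller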
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*):
            The size of the projected conditioning; defaults to `2 * embedding_dim`.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the layer norm has learnable per-element affine parameters.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon used by the layer norm.
        chunk_dim (`int`, defaults to `0`):
            The dimension along which the projected conditioning is split into scale and shift.
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
            # Note the reversed order: this branch splits into "shift, scale" while the other
            # branch splits into "scale, shift". This branch is specific to CogVideoX for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
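# A minimal usage sketch of the chunk_dim=1 path used by CogVideoX (hypothetical sizes,
# not part of the original file): the conditioning embedding is projected to output_dim
# and split into shift/scale that modulate the layer-normalized input of size output_dim // 2.
if __name__ == "__main__":
    ada_norm = AdaLayerNorm(embedding_dim=512, output_dim=256, chunk_dim=1)
    x = torch.randn(2, 226, 128)   # feature dim must equal output_dim // 2
    temb = torch.randn(2, 512)     # conditioning of size embedding_dim
    assert ada_norm(x, temb=temb).shape == x.shape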

View File

@ -1,67 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXUpsample3D(nn.Module):
r"""
    A 3D upsample layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: Wait for paper release.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `1`):
Stride of the convolution.
padding (`int`, defaults to `1`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
            Whether the time dimension is compressed; if `True`, the time axis is upsampled along with the spatial axes.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = F.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
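# A minimal usage sketch (hypothetical sizes, not part of the original file): with
# compress_time=False only H and W are doubled; with compress_time=True the time axis
# is upsampled as well, with the first frame handled separately when T is odd.
if __name__ == "__main__":
    upsample = CogVideoXUpsample3D(in_channels=256, out_channels=256, compress_time=False)
    video = torch.randn(1, 256, 5, 32, 32)  # [B, C, T, H, W]
    assert upsample(video).shape == (1, 256, 5, 64, 64)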

View File

@ -1,64 +0,0 @@
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
spatial_broadcast: bool = False,
spatial_threshold: list = None,
spatial_range: int = None,
temporal_broadcast: bool = False,
temporal_threshold: list = None,
temporal_range: int = None,
mlp_broadcast: bool = False,
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_range = temporal_range
self.mlp_broadcast = mlp_broadcast
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
self.mlp_temporal_outputs = {}
self.mlp_spatial_outputs = {}
class CogVideoXPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
temporal_broadcast: bool = False,
temporal_threshold: list = [100, 850],
temporal_range: int = 4,
cross_broadcast: bool = False,
cross_threshold: list = [100, 850],
cross_range: int = 6,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range
)
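# A minimal usage sketch (hypothetical values, not part of the original file): with
# spatial broadcasting enabled, the spatial attention output is reused for up to
# spatial_range consecutive denoising steps while the timestep lies inside
# spatial_threshold, trading a small quality cost for fewer attention evaluations.
if __name__ == "__main__":
    pab_config = CogVideoXPABConfig(
        steps=50, spatial_broadcast=True, spatial_threshold=[100, 850], spatial_range=2
    )
    print(pab_config.spatial_broadcast, pab_config.spatial_threshold, pab_config.spatial_range)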