kijai 2024-11-09 02:24:18 +02:00
parent 2074ba578e
commit e783951dad
3 changed files with 262 additions and 242 deletions

View File

@ -21,20 +21,18 @@ import torch.nn.functional as F
import numpy as np
from einops import rearrange
from functools import reduce
from operator import mul
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from diffusers.loaders import PeftAdapterMixin
from .embeddings import CogVideoX1_1PatchEmbed
from .embeddings import CogVideoXPatchEmbed
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -64,14 +62,6 @@ def fft(tensor):
return low_freq_fft, high_freq_fft
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
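# A minimal, runnable sketch (toy shapes, not the model's) of the identity the
# rotate_half/rotary helpers in this hunk implement: t * cos(theta) + rotate_half(t) * sin(theta).
import torch
from einops import rearrange

def _rotate_half_demo(x):
    # pair the last dim as (d, 2), map (x1, x2) -> (-x2, x1), then flatten back
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")

_t = torch.randn(1, 1, 4, 8)       # (batch, heads, tokens, head_dim)
_theta = torch.randn(4, 8)         # per-position angles, purely illustrative
assert (_t * _theta.cos() + _rotate_half_demo(_t) * _theta.sin()).shape == _t.shape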
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@ -81,16 +71,7 @@ class CogVideoXAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def rotary(self, t, rope_args):
def reshape_freq(freqs):
freqs = freqs[: rope_args["T"], : rope_args["H"], : rope_args["W"]].contiguous()
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
return t * freqs_cos + rotate_half(t) * freqs_sin
@torch.compiler.disable()
def __call__(
self,
@ -99,7 +80,6 @@ class CogVideoXAttnProcessor2_0:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
rope_args: Optional[dict] = None
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -129,127 +109,118 @@ class CogVideoXAttnProcessor2_0:
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
self.freqs_cos = image_rotary_emb[0]
self.freqs_sin = image_rotary_emb[1]
print("rope args", rope_args) #{'T': 6, 'H': 30, 'W': 45, 'seq_length': 8775}
print("freqs_cos", self.freqs_cos.shape) #torch.Size([13, 30, 45, 64])
print("freqs_sin", self.freqs_sin.shape)
from diffusers.models.embeddings import apply_rotary_emb
#query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
query = torch.cat(
(query[:, :, : text_seq_length],
self.rotary(query[:, :, text_seq_length:],
rope_args)),
dim=2)
if not attn.is_cross_attention:
#key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
key = torch.cat(
(key[ :, :, : text_seq_length],
self.rotary(key[:, :, text_seq_length:],
rope_args)),
dim=2)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
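# SAGEATTN_IS_AVAILABLE and sageattn are defined near the top of this file, outside the
# diff context. A typical optional-dependency guard would look roughly like the sketch
# below; the exact import is an assumption, not code from this commit.
try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False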
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
has_nan = torch.isnan(query).any()
if has_nan:
raise ValueError(f"query before rope has nan: {has_nan}")
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
#if SAGEATTN_IS_AVAILABLE:
# hidden_states = sageattn(query, key, value, is_causal=False)
#else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
has_nan = torch.isnan(hidden_states).any()
if has_nan:
raise ValueError(f"hs after scaled_dot_product_attention has nan: {has_nan}")
has_inf = torch.isinf(hidden_states).any()
if has_inf:
raise ValueError(f"hs after scaled_dot_product_attention has inf: {has_inf}")
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
has_nan = torch.isnan(hidden_states).any()
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
# class FusedCogVideoXAttnProcessor2_0:
# r"""
# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
# query and key vectors, but does not include spatial normalization.
# """
# def __init__(self):
# if not hasattr(F, "scaled_dot_product_attention"):
# raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# @torch.compiler.disable()
# def __call__(
# self,
# attn: Attention,
# hidden_states: torch.Tensor,
# encoder_hidden_states: torch.Tensor,
# attention_mask: Optional[torch.Tensor] = None,
# image_rotary_emb: Optional[torch.Tensor] = None,
# ) -> torch.Tensor:
# print("FusedCogVideoXAttnProcessor2_0")
# text_seq_length = encoder_hidden_states.size(1)
# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# batch_size, sequence_length, _ = (
# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# )
# if attention_mask is not None:
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# qkv = attn.to_qkv(hidden_states)
# split_size = qkv.shape[-1] // 3
# query, key, value = torch.split(qkv, split_size, dim=-1)
# inner_dim = key.shape[-1]
# head_dim = inner_dim // attn.heads
# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# if attn.norm_q is not None:
# query = attn.norm_q(query)
# if attn.norm_k is not None:
# key = attn.norm_k(key)
# # Apply RoPE if needed
# if image_rotary_emb is not None:
# from diffusers.models.embeddings import apply_rotary_emb
# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
# if not attn.is_cross_attention:
# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
# if SAGEATTN_IS_AVAILABLE:
# hidden_states = sageattn(query, key, value, is_causal=False)
# else:
# hidden_states = F.scaled_dot_product_attention(
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# # linear proj
# hidden_states = attn.to_out[0](hidden_states)
# # dropout
# hidden_states = attn.to_out[1](hidden_states)
# encoder_hidden_states, hidden_states = hidden_states.split(
# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
# )
# return hidden_states, encoder_hidden_states
#region Blocks
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
@ -344,14 +315,14 @@ class CogVideoXBlock(nn.Module):
fuser=None,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
rope_args=None
fastercache_device="cuda:0"
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
@ -378,7 +349,7 @@ class CogVideoXBlock(nn.Module):
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb, rope_args=rope_args
image_rotary_emb=image_rotary_emb
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@ -386,10 +357,18 @@ class CogVideoXBlock(nn.Module):
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# has_nan = torch.isnan(hidden_states).any()
# if has_nan:
# raise ValueError(f"hs before norm2 has nan: {has_nan}")
# has_inf = torch.isinf(hidden_states).any()
# if has_inf:
# raise ValueError(f"hs before norm2 has inf: {has_inf}")
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
@ -404,7 +383,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
#region Transformer
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@ -479,6 +458,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
patch_size_t: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
@ -489,6 +469,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -501,12 +482,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
# 1. Patch embedding
self.patch_embed = CogVideoX1_1PatchEmbed(
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
patch_size_t=patch_size_t,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
#bias=True,
bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
@ -550,7 +532,14 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)
if patch_size_t is None:
# For CogVideoX 1.0
output_dim = patch_size * patch_size * out_channels
else:
# For CogVideoX 1.5
output_dim = patch_size * patch_size * patch_size_t * out_channels
self.proj_out = nn.Linear(inner_dim, output_dim)
self.gradient_checkpointing = False
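# Illustrative arithmetic for the proj_out size chosen above, assuming the defaults shown in
# this signature (patch_size=2, patch_size_t=2) and out_channels=16 (an assumption; the real
# values come from the model config):
#   CogVideoX 1.0: patch_size * patch_size * out_channels                = 2 * 2 * 16     = 64
#   CogVideoX 1.5: patch_size * patch_size * patch_size_t * out_channels = 2 * 2 * 2 * 16 = 128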
@ -626,44 +615,44 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
# def fuse_qkv_projections(self):
# """
# Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
# are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
# <Tip warning={true}>
This API is 🧪 experimental.
# This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
# </Tip>
# """
# self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# for _, attn_processor in self.attn_processors.items():
# if "Added" in str(attn_processor.__class__.__name__):
# raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
# self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# for module in self.modules():
# if isinstance(module, Attention):
# module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
# def unfuse_qkv_projections(self):
# """Disables the fused QKV projection if enabled.
<Tip warning={true}>
# <Tip warning={true}>
This API is 🧪 experimental.
# This API is 🧪 experimental.
</Tip>
# </Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# """
# if self.original_attn_processors is not None:
# self.set_attn_processor(self.original_attn_processors)
def forward(
self,
@ -678,9 +667,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
p = self.config.patch_size
print("p", p)
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
@ -691,25 +678,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# RoPE
seq_length = num_frames * height * width // reduce(mul, [p, p, p])
rope_T = num_frames // p
rope_H = height // p
rope_W = width // p
rope_args = {
"T": rope_T,
"H": rope_H,
"W": rope_W,
"seq_length": seq_length,
}
# 2. Patch embedding
p = self.config.patch_size
p_t = self.config.patch_size_t
# The hidden states' height and width are always divisible by patch_size,
# but the number of frames may not be divisible by patch_size_t, so we pad by repeating the first frame.
if p_t is not None:
remaining_frames = p_t - num_frames % p_t
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
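# Worked example of the padding above (illustrative numbers): with p_t = 2 and
# num_frames = 13, remaining_frames = 2 - 13 % 2 = 1, so the first latent frame is
# repeated once and 14 frames enter the patch embed; the duplicated frames are
# dropped again after unpatchify via `output[:, remaining_frames:]` below.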
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
if self.use_fastercache:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
@ -754,8 +740,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -777,6 +770,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
output = torch.cat([output, recovered_uncond])
else:
for i, block in enumerate(self.transformer_blocks):
print("block", i)
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
@ -785,9 +779,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device,
rope_args=rope_args
fastercache_device = self.fastercache_device
)
has_nan = torch.isnan(hidden_states).any()
if has_nan:
raise ValueError(f"block output hidden_states has nan: {has_nan}")
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
@ -816,9 +812,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
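# A minimal, runnable shape-check (toy sizes, run outside the model) of the CogVideoX 1.5
# unpatchify above: (B, T' * H/p * W/p tokens, p*p*p_t*C) -> (B, T, C, H, W), padded frames dropped.
import torch
B, C, H, W, p, p_t = 1, 16, 60, 90, 2, 2
num_frames, remaining_frames = 13, 1                  # padded to 14 latent frames upstream
T_pad = (num_frames + remaining_frames) // p_t        # 7 temporal patches
hidden = torch.randn(B, T_pad * (H // p) * (W // p), p * p * p_t * C)   # stands in for proj_out output
out = hidden.reshape(B, T_pad, H // p, W // p, C, p_t, p, p)
out = out.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
out = out[:, remaining_frames:]
assert out.shape == (B, num_frames, C, H, W)          # (1, 13, 16, 60, 90)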
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Union
from typing import Tuple, Union, Optional
def get_1d_rotary_pos_embed(
dim: int,
@ -123,9 +123,9 @@ def get_3d_rotary_pos_embed(
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
#freqs = freqs.view(
# temporal_size * grid_size_h * grid_size_w, -1
#) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
@ -236,16 +236,18 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class CogVideoX1_1PatchEmbed(nn.Module):
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
patch_size_t: Optional[int] = None,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 81,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
@ -255,8 +257,8 @@ class CogVideoX1_1PatchEmbed(nn.Module):
) -> None:
super().__init__()
# Adjust patch_size to handle three dimensions
self.patch_size = (patch_size, patch_size, patch_size) # (depth, height, width)
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
@ -268,8 +270,15 @@ class CogVideoX1_1PatchEmbed(nn.Module):
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
# Use Linear layer for projection
self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim)
if patch_size_t is None:
# CogVideoX 1.0 checkpoints
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
else:
# CogVideoX 1.5 checkpoints
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings or use_learned_positional_embeddings:
@ -278,8 +287,8 @@ class CogVideoX1_1PatchEmbed(nn.Module):
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size[1]
post_patch_width = sample_width // self.patch_size[2]
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
@ -291,44 +300,46 @@ class CogVideoX1_1PatchEmbed(nn.Module):
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
return joint_pos_embedding
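# Worked example with the defaults above (sample_height=60, sample_width=90, sample_frames=49,
# temporal_compression_ratio=4, patch_size=2, max_text_seq_length=226):
# post_patch grid = 30 x 45, post_time_compression_frames = (49 - 1) // 4 + 1 = 13,
# num_patches = 30 * 45 * 13 = 17550, so the joint table holds 226 + 17550 = 17776 positions.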
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
"""
r"""
Args:
text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
first_frame = image_embeds[:, 0:1, :, :, :]
duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1) # (batch, 2, channels, height, width)
# Copy the first frames, for t_patch
image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
rope_patch_t = num_frames // self.patch_size[0]
rope_patch_h = height // self.patch_size[1]
rope_patch_w = width // self.patch_size[2]
batch_size, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.view(
batch,
rope_patch_t, self.patch_size[0],
rope_patch_h, self.patch_size[1],
rope_patch_w, self.patch_size[2],
channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
image_embeds = self.proj(image_embeds)
# Concatenate text and image embeddings
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
else:
p = self.patch_size
p_t = self.patch_size_t
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
# Add positional embeddings if applicable
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
raise ValueError(
@ -339,9 +350,9 @@ class CogVideoX1_1PatchEmbed(nn.Module):
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
@ -350,4 +361,5 @@ class CogVideoX1_1PatchEmbed(nn.Module):
embeds = embeds + pos_embedding
return embeds
return embeds
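# A minimal, runnable shape-check (toy sizes, run outside the model) of the CogVideoX 1.5
# patchify branch above: (B, T, C, H, W) latents -> (B, T/p_t * H/p * W/p, p_t*p*p*C) tokens.
import torch
B, T, C, H, W, p, p_t = 1, 14, 16, 60, 90, 2, 2       # T already padded to a multiple of p_t
x = torch.randn(B, T, C, H, W)
x = x.permute(0, 1, 3, 4, 2)                          # channels-last: (B, T, H, W, C)
x = x.reshape(B, T // p_t, p_t, H // p, p, W // p, p, C)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
assert x.shape == (1, 9450, 128)                      # 7 * 30 * 45 tokens, 2*2*2*16 features; proj maps 128 -> embed_dim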

View File

@ -21,6 +21,7 @@ import torch.nn.functional as F
import math
from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
@ -115,7 +116,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@ -298,18 +299,18 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
# def fuse_qkv_projections(self) -> None:
# r"""Enables fused QKV projections."""
# self.fusing_transformer = True
# self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
# def unfuse_qkv_projections(self) -> None:
# r"""Disable QKV projection fusion if enabled."""
# if not self.fusing_transformer:
# logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
# else:
# self.transformer.unfuse_qkv_projections()
# self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
@ -322,8 +323,12 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
@ -332,7 +337,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
temporal_size=base_num_frames,
use_real=True,
)
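# Illustrative arithmetic for the RoPE grid above (toy values; the real numbers come from the
# transformer config and the requested resolution): with vae_scale_factor_spatial=8, patch_size=2,
# and patch_size_t=2, a 480x720 request gives grid_height = 480 // 16 = 30, grid_width = 720 // 16 = 45,
# and 13 latent frames give base_num_frames = (13 + 2 - 1) // 2 = 7 as the temporal_size.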