Author: kijai
Date: 2024-11-09 02:24:18 +02:00
Parent: 2074ba578e
Commit: e783951dad
3 changed files with 262 additions and 242 deletions

View File

@@ -21,20 +21,18 @@ import torch.nn.functional as F
 import numpy as np
 from einops import rearrange
-from functools import reduce
-from operator import mul
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from diffusers.models.attention import Attention, FeedForward
 from diffusers.models.attention_processor import AttentionProcessor
-from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from diffusers.loaders import PeftAdapterMixin
-from .embeddings import CogVideoX1_1PatchEmbed
+from .embeddings import CogVideoXPatchEmbed
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -64,14 +62,6 @@ def fft(tensor):
     return low_freq_fft, high_freq_fft
-def rotate_half(x):
-    x = rearrange(x, "... (d r) -> ... d r", r=2)
-    x1, x2 = x.unbind(dim=-1)
-    x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, "... d r -> ... (d r)")
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -81,16 +71,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    def rotary(self, t, rope_args):
-        def reshape_freq(freqs):
-            freqs = freqs[: rope_args["T"], : rope_args["H"], : rope_args["W"]].contiguous()
-            freqs = rearrange(freqs, "t h w d -> (t h w) d")
-            freqs = freqs.unsqueeze(0).unsqueeze(0)
-            return freqs
-        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
-        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
-        return t * freqs_cos + rotate_half(t) * freqs_sin
     @torch.compiler.disable()
     def __call__(
         self,
@@ -99,7 +80,6 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        rope_args: Optional[dict] = None
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
@@ -129,127 +109,118 @@ class CogVideoXAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            self.freqs_cos = image_rotary_emb[0]
-            self.freqs_sin = image_rotary_emb[1]
-            print("rope args", rope_args) #{'T': 6, 'H': 30, 'W': 45, 'seq_length': 8775}
-            print("freqs_cos", self.freqs_cos.shape) #torch.Size([13, 30, 45, 64])
-            print("freqs_sin", self.freqs_sin.shape)
-            from diffusers.models.embeddings import apply_rotary_emb
-            #query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            query = torch.cat(
-                (query[:, :, : text_seq_length],
-                 self.rotary(query[:, :, text_seq_length:],
-                             rope_args)),
-                dim=2)
-            if not attn.is_cross_attention:
-                #key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-                key = torch.cat(
-                    (key[ :, :, : text_seq_length],
-                     self.rotary(key[:, :, text_seq_length:],
-                                 rope_args)),
-                    dim=2)
-        if SAGEATTN_IS_AVAILABLE:
-            hidden_states = sageattn(query, key, value, is_causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-        )
-        return hidden_states, encoder_hidden_states
-class FusedCogVideoXAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-    query and key vectors, but does not include spatial normalization.
-    """
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    @torch.compiler.disable()
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-        # Apply RoPE if needed
         if image_rotary_emb is not None:
             from diffusers.models.embeddings import apply_rotary_emb
+            has_nan = torch.isnan(query).any()
+            if has_nan:
+                raise ValueError(f"query before rope has nan: {has_nan}")
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-        if SAGEATTN_IS_AVAILABLE:
-            hidden_states = sageattn(query, key, value, is_causal=False)
-        else:
+        #if SAGEATTN_IS_AVAILABLE:
+        #    hidden_states = sageattn(query, key, value, is_causal=False)
+        #else:
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
+        has_nan = torch.isnan(hidden_states).any()
+        if has_nan:
+            raise ValueError(f"hs after scaled_dot_product_attention has nan: {has_nan}")
+        has_inf = torch.isinf(hidden_states).any()
+        if has_inf:
+            raise ValueError(f"hs after scaled_dot_product_attention has inf: {has_inf}")
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
+        has_nan = torch.isnan(hidden_states).any()
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
         encoder_hidden_states, hidden_states = hidden_states.split(
             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
         )
         return hidden_states, encoder_hidden_states
+# class FusedCogVideoXAttnProcessor2_0:
+#     r"""
+#     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+#     query and key vectors, but does not include spatial normalization.
+#     """
+#     def __init__(self):
+#         if not hasattr(F, "scaled_dot_product_attention"):
+#             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+#     @torch.compiler.disable()
+#     def __call__(
+#         self,
+#         attn: Attention,
+#         hidden_states: torch.Tensor,
+#         encoder_hidden_states: torch.Tensor,
+#         attention_mask: Optional[torch.Tensor] = None,
+#         image_rotary_emb: Optional[torch.Tensor] = None,
+#     ) -> torch.Tensor:
+#         print("FusedCogVideoXAttnProcessor2_0")
+#         text_seq_length = encoder_hidden_states.size(1)
+#         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+#         batch_size, sequence_length, _ = (
+#             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+#         )
+#         if attention_mask is not None:
+#             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+#             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+#         qkv = attn.to_qkv(hidden_states)
+#         split_size = qkv.shape[-1] // 3
+#         query, key, value = torch.split(qkv, split_size, dim=-1)
+#         inner_dim = key.shape[-1]
+#         head_dim = inner_dim // attn.heads
+#         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+#         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+#         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+#         if attn.norm_q is not None:
+#             query = attn.norm_q(query)
+#         if attn.norm_k is not None:
+#             key = attn.norm_k(key)
+#         # Apply RoPE if needed
+#         if image_rotary_emb is not None:
+#             from diffusers.models.embeddings import apply_rotary_emb
+#             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+#             if not attn.is_cross_attention:
+#                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+#         if SAGEATTN_IS_AVAILABLE:
+#             hidden_states = sageattn(query, key, value, is_causal=False)
+#         else:
+#             hidden_states = F.scaled_dot_product_attention(
+#                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+#             )
+#         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+#         # linear proj
+#         hidden_states = attn.to_out[0](hidden_states)
+#         # dropout
+#         hidden_states = attn.to_out[1](hidden_states)
+#         encoder_hidden_states, hidden_states = hidden_states.split(
+#             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+#         )
+#         return hidden_states, encoder_hidden_states
+#region Blocks
 @maybe_allow_in_graph
 class CogVideoXBlock(nn.Module):
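Note: the deleted rotate_half/rotary helpers and diffusers' apply_rotary_emb compute the same rotation, t * cos + rotate_half(t) * sin, applied per channel pair, and only the image tokens after text_seq_length are rotated. A minimal, self-contained sketch of that operation (illustrative only, with assumed shapes; not the diffusers implementation):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into adjacent pairs (x1, x2) and rotate each pair to (-x2, x1)
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Elementwise rotary embedding; cos/sin broadcast over batch and heads
    return t * cos + rotate_half(t) * sin

query = torch.randn(1, 2, 8, 16)       # (batch, heads, seq, head_dim), hypothetical sizes
cos = torch.randn(8, 16)
sin = torch.randn(8, 16)
rotated = apply_rope(query, cos, sin)  # same shape as query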
@@ -344,14 +315,14 @@ class CogVideoXBlock(nn.Module):
         fuser=None,
         fastercache_counter=0,
         fastercache_start_step=15,
-        fastercache_device="cuda:0",
-        rope_args=None
+        fastercache_device="cuda:0"
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
             H, W = video_flow_feature.shape[-2:]
@@ -378,7 +349,7 @@
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb, rope_args=rope_args
+            image_rotary_emb=image_rotary_emb
         )
         if fastercache_counter == fastercache_start_step:
             self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -386,10 +357,18 @@
         elif fastercache_counter > fastercache_start_step:
             self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
             self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+        # has_nan = torch.isnan(hidden_states).any()
+        # if has_nan:
+        #     raise ValueError(f"hs before norm2 has nan: {has_nan}")
+        # has_inf = torch.isinf(hidden_states).any()
+        # if has_inf:
+        #     raise ValueError(f"hs before norm2 has inf: {has_inf}")
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
             hidden_states, encoder_hidden_states, temb
@@ -404,7 +383,7 @@
         return hidden_states, encoder_hidden_states
+#region Transformer
 class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -479,6 +458,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: int = 2,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -489,6 +469,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -501,12 +482,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
         # 1. Patch embedding
-        self.patch_embed = CogVideoX1_1PatchEmbed(
+        self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            #bias=True,
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -550,7 +532,14 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+        self.proj_out = nn.Linear(inner_dim, output_dim)
         self.gradient_checkpointing = False
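Note: the proj_out width has to match the unpatchify step later in forward(): each output token expands back into a p x p spatial patch for 1.0 checkpoints, or a p_t x p x p spatio-temporal patch for 1.5. A quick check of the arithmetic with assumed values:

patch_size, patch_size_t, out_channels = 2, 2, 16
dim_1_0 = patch_size * patch_size * out_channels                  # 64 output features per token (1.0 path)
dim_1_5 = patch_size * patch_size * patch_size_t * out_channels   # 128 output features per token (1.5 path)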
@@ -626,44 +615,44 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            fn_recursive_attn_processor(name, module, processor)
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-        <Tip warning={true}>
-        This API is 🧪 experimental.
-        </Tip>
-        """
-        self.original_attn_processors = None
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-        self.original_attn_processors = self.attn_processors
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+    # def fuse_qkv_projections(self):
+    #     """
+    #     Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+    #     are fused. For cross-attention modules, key and value projection matrices are fused.
+    #     <Tip warning={true}>
+    #     This API is 🧪 experimental.
+    #     </Tip>
+    #     """
+    #     self.original_attn_processors = None
+    #     for _, attn_processor in self.attn_processors.items():
+    #         if "Added" in str(attn_processor.__class__.__name__):
+    #             raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+    #     self.original_attn_processors = self.attn_processors
+    #     for module in self.modules():
+    #         if isinstance(module, Attention):
+    #             module.fuse_projections(fuse=True)
+    #     self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-        <Tip warning={true}>
-        This API is 🧪 experimental.
-        </Tip>
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+    # def unfuse_qkv_projections(self):
+    #     """Disables the fused QKV projection if enabled.
+    #     <Tip warning={true}>
+    #     This API is 🧪 experimental.
+    #     </Tip>
+    #     """
+    #     if self.original_attn_processors is not None:
+    #         self.set_attn_processor(self.original_attn_processors)
     def forward(
         self,
@@ -678,9 +667,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
-        p = self.config.patch_size
-        print("p", p)
         # 1. Time embedding
         timesteps = timestep
         t_emb = self.time_proj(timesteps)
@@ -691,25 +678,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-        # RoPE
-        seq_length = num_frames * height * width // reduce(mul, [p, p, p])
-        rope_T = num_frames // p
-        rope_H = height // p
-        rope_W = width // p
-        rope_args = {
-            "T": rope_T,
-            "H": rope_H,
-            "W": rope_W,
-            "seq_length": seq_length,
-        }
         # 2. Patch embedding
+        p = self.config.patch_size
+        p_t = self.config.patch_size_t
+        # We know that the hidden states height and width will always be divisible by patch_size.
+        # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
+        if p_t is not None:
+            remaining_frames = p_t - num_frames % p_t
+            first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
+            hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
         if self.use_fastercache:
             self.fastercache_counter+=1
             if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
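Note: the padding added before patch embedding repeats the first latent frame so the frame count becomes divisible by patch_size_t; the same remaining_frames count is cropped off again after unpatchify. A small worked example of the arithmetic (assumed shapes):

import torch

p_t = 2
num_frames = 13                                  # e.g. latent frames for a 49-frame video with 4x temporal compression
remaining_frames = p_t - num_frames % p_t        # 1 frame of padding needed

hidden_states = torch.randn(1, num_frames, 16, 60, 90)  # (B, T, C, H, W), hypothetical latent
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
print(hidden_states.shape[1])                    # 14, divisible by p_t

Note that when num_frames is already divisible by p_t this formula still pads (and later crops) a full group of p_t frames.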
@@ -754,8 +740,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
            # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
-            output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
-            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+            if p_t is None:
+                output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+                output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+            else:
+                output = hidden_states.reshape(
+                    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+                )
+                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+                output = output[:, remaining_frames:]
             (bb, tt, cc, hh, ww) = output.shape
             cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -777,6 +770,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 output = torch.cat([output, recovered_uncond])
         else:
             for i, block in enumerate(self.transformer_blocks):
+                print("block", i)
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -785,9 +779,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter = self.fastercache_counter,
-                    fastercache_device = self.fastercache_device,
-                    rope_args=rope_args
+                    fastercache_device = self.fastercache_device
                 )
+                has_nan = torch.isnan(hidden_states).any()
+                if has_nan:
+                    raise ValueError(f"block output hidden_states has nan: {has_nan}")
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
                     controlnet_states_block = controlnet_states[i]
@@ -816,9 +812,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            # Note: we use `-1` instead of `channels`:
            # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
            # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
-            p = self.config.patch_size
-            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+            if p_t is None:
+                output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+                output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+            else:
+                output = hidden_states.reshape(
+                    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+                )
+                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+                output = output[:, remaining_frames:]
         if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
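Note: the 1.5 unpatchify is the inverse of the patch embedding: each token is unfolded into a p_t x p x p block, the blocks are scattered back onto the (T, H, W) grid, and the frames padded before patching are cropped off. A shape-only sketch with assumed sizes:

import torch

batch_size, out_channels = 1, 16
p, p_t = 2, 2
num_frames, height, width = 14, 60, 90           # padded latent sizes (hypothetical)
remaining_frames = 1                              # frames added before patching

num_tokens = (num_frames // p_t) * (height // p) * (width // p)
hidden_states = torch.randn(batch_size, num_tokens, p_t * p * p * out_channels)  # proj_out output

output = hidden_states.reshape(
    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
print(output.shape)  # torch.Size([1, 13, 16, 60, 90]) -> (B, T, C, H, W)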

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 def get_1d_rotary_pos_embed(
     dim: int,
@@ -123,9 +123,9 @@ def get_3d_rotary_pos_embed(
         freqs = torch.cat(
             [freqs_t, freqs_h, freqs_w], dim=-1
         )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
-        #freqs = freqs.view(
-        #    temporal_size * grid_size_h * grid_size_w, -1
-        #)  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
         return freqs
     t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
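Note: with the view re-enabled, get_3d_rotary_pos_embed returns one frequency row per flattened (t, h, w) position, matching the t-major token order produced by the patch embedding. A tiny illustration of the flattening (assumed sizes):

import torch

temporal_size, grid_h, grid_w, dim = 7, 30, 45, 64        # hypothetical RoPE grid
freqs = torch.randn(temporal_size, grid_h, grid_w, dim)    # one frequency vector per (t, h, w)
freqs = freqs.view(temporal_size * grid_h * grid_w, -1)    # (t*h*w, dim)
print(freqs.shape)  # torch.Size([9450, 64])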
@@ -236,16 +236,18 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
     return emb
-class CogVideoX1_1PatchEmbed(nn.Module):
+class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
+        bias: bool = True,
         sample_width: int = 90,
         sample_height: int = 60,
-        sample_frames: int = 81,
+        sample_frames: int = 49,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         spatial_interpolation_scale: float = 1.875,
@@ -255,8 +257,8 @@ class CogVideoX1_1PatchEmbed(nn.Module):
     ) -> None:
         super().__init__()
-        # Adjust patch_size to handle three dimensions
-        self.patch_size = (patch_size, patch_size, patch_size) # (depth, height, width)
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -268,8 +270,15 @@ class CogVideoX1_1PatchEmbed(nn.Module):
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings
-        # Use Linear layer for projection
-        self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim)
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -278,8 +287,8 @@ class CogVideoX1_1PatchEmbed(nn.Module):
             self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
     def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
-        post_patch_height = sample_height // self.patch_size[1]
-        post_patch_width = sample_width // self.patch_size[2]
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
         num_patches = post_patch_height * post_patch_width * post_time_compression_frames
@@ -291,44 +300,46 @@ class CogVideoX1_1PatchEmbed(nn.Module):
             self.temporal_interpolation_scale,
         )
         pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
-        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
-        joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
         return joint_pos_embedding
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
-        """
+        r"""
         Args:
-            text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
-            image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
         """
         text_embeds = self.text_proj(text_embeds)
-        first_frame = image_embeds[:, 0:1, :, :, :]
-        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1) # (batch, 2, channels, height, width)
-        # Copy the first frames, for t_patch
-        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
-        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
-        rope_patch_t = num_frames // self.patch_size[0]
-        rope_patch_h = height // self.patch_size[1]
-        rope_patch_w = width // self.patch_size[2]
-        image_embeds = image_embeds.view(
-            batch,
-            rope_patch_t, self.patch_size[0],
-            rope_patch_h, self.patch_size[1],
-            rope_patch_w, self.patch_size[2],
-            channels
-        )
-        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
-        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
-        image_embeds = self.proj(image_embeds)
-        # Concatenate text and image embeddings
-        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
-        # Add positional embeddings if applicable
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
         if self.use_positional_embeddings or self.use_learned_positional_embeddings:
             if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                 raise ValueError(
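Note: both forward branches above end with the same token layout, [batch, T' * H' * W', embed_dim]; the 1.0 path gets there by convolving each frame with the Conv2d, while the 1.5 path flattens every p_t x p x p block and projects it with the Linear layer. A shape-only walk-through of the 1.5 branch under assumed sizes:

import torch
import torch.nn as nn

batch_size, num_frames, channels, height, width = 1, 14, 16, 60, 90  # padded latent (hypothetical)
p, p_t, embed_dim = 2, 2, 3072

proj = nn.Linear(channels * p * p * p_t, embed_dim)

x = torch.randn(batch_size, num_frames, channels, height, width)
x = x.permute(0, 1, 3, 4, 2)                                          # (B, T, H, W, C)
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)      # (B, T'*H'*W', p_t*p*p*C)
tokens = proj(x)
print(tokens.shape)  # torch.Size([1, 9450, 3072])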
@@ -339,9 +350,9 @@ class CogVideoX1_1PatchEmbed(nn.Module):
             pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
             if (
                 self.sample_height != height
                 or self.sample_width != width
                 or self.sample_frames != pre_time_compression_frames
             ):
                 pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
                 pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
@@ -350,4 +361,5 @@ class CogVideoX1_1PatchEmbed(nn.Module):
                 embeds = embeds + pos_embedding
         return embeds

View File

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 import math
 from diffusers.models import AutoencoderKLCogVideoX#, CogVideoXTransformer3DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
@@ -115,7 +116,7 @@
         timesteps = scheduler.timesteps
     return timesteps, num_inference_steps
-class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation using CogVideoX.
@@ -298,18 +299,18 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
         return weights
-    def fuse_qkv_projections(self) -> None:
-        r"""Enables fused QKV projections."""
-        self.fusing_transformer = True
-        self.transformer.fuse_qkv_projections()
-    def unfuse_qkv_projections(self) -> None:
-        r"""Disable QKV projection fusion if enabled."""
-        if not self.fusing_transformer:
-            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
-        else:
-            self.transformer.unfuse_qkv_projections()
-            self.fusing_transformer = False
+    # def fuse_qkv_projections(self) -> None:
+    #     r"""Enables fused QKV projections."""
+    #     self.fusing_transformer = True
+    #     self.transformer.fuse_qkv_projections()
+    # def unfuse_qkv_projections(self) -> None:
+    #     r"""Disable QKV projection fusion if enabled."""
+    #     if not self.fusing_transformer:
+    #         logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+    #     else:
+    #         self.transformer.unfuse_qkv_projections()
+    #         self.fusing_transformer = False
     def _prepare_rotary_positional_embeddings(
         self,
@@ -322,8 +323,12 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t or 1
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+        base_num_frames = (num_frames + p_t - 1) // p_t
         grid_crops_coords = get_resize_crop_region_for_grid(
             (grid_height, grid_width), base_size_width, base_size_height
@@ -332,7 +337,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
             embed_dim=self.transformer.config.attention_head_dim,
             crops_coords=grid_crops_coords,
             grid_size=(grid_height, grid_width),
-            temporal_size=num_frames,
+            temporal_size=base_num_frames,
             use_real=True,
         )
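Note: for 1.5-style checkpoints the RoPE base grid now comes from the model's configured sample size instead of the hard-coded 720x480, and the temporal axis is counted in groups of patch_size_t frames. A sketch of the bookkeeping with assumed config values:

vae_scale_factor_spatial = 8
height, width = 768, 768                         # requested output resolution (hypothetical)
p, p_t = 2, 2
sample_width, sample_height = 300, 300           # hypothetical transformer.config values

grid_height = height // (vae_scale_factor_spatial * p)   # 48
grid_width = width // (vae_scale_factor_spatial * p)     # 48
base_size_width = sample_width // p                      # 150
base_size_height = sample_height // p                    # 150

num_frames = (81 - 1) // 4 + 1                           # 21 latent frames for an 81-frame video
base_num_frames = (num_frames + p_t - 1) // p_t          # 11 temporal RoPE positions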