diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index c751d13..9bbd87b 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -73,6 +73,7 @@ class CogVideoXAttnProcessor2_0:
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
+
def __call__(
self,
attn: Attention,
@@ -115,14 +116,16 @@ class CogVideoXAttnProcessor2_0:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- if SAGEATTN_IS_AVAILABLE:
- hidden_states = sageattn(query, key, value, is_causal=False)
- else:
- hidden_states = F.scaled_dot_product_attention(
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ #if SAGEATTN_IS_AVAILABLE:
+ # hidden_states = sageattn(query, key, value, is_causal=False)
+ #else:
+ hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
+ if torch.isinf(hidden_states).any():
+ raise ValueError(f"hidden_states after dot product has inf")
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -136,81 +139,6 @@ class CogVideoXAttnProcessor2_0:
)
return hidden_states, encoder_hidden_states
-
-# class FusedCogVideoXAttnProcessor2_0:
-# r"""
-# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-# query and key vectors, but does not include spatial normalization.
-# """
-
-# def __init__(self):
-# if not hasattr(F, "scaled_dot_product_attention"):
-# raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-# @torch.compiler.disable()
-# def __call__(
-# self,
-# attn: Attention,
-# hidden_states: torch.Tensor,
-# encoder_hidden_states: torch.Tensor,
-# attention_mask: Optional[torch.Tensor] = None,
-# image_rotary_emb: Optional[torch.Tensor] = None,
-# ) -> torch.Tensor:
-# print("FusedCogVideoXAttnProcessor2_0")
-# text_seq_length = encoder_hidden_states.size(1)
-
-# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-# batch_size, sequence_length, _ = (
-# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-# )
-
-# if attention_mask is not None:
-# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-# qkv = attn.to_qkv(hidden_states)
-# split_size = qkv.shape[-1] // 3
-# query, key, value = torch.split(qkv, split_size, dim=-1)
-
-# inner_dim = key.shape[-1]
-# head_dim = inner_dim // attn.heads
-
-# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-# if attn.norm_q is not None:
-# query = attn.norm_q(query)
-# if attn.norm_k is not None:
-# key = attn.norm_k(key)
-
-# # Apply RoPE if needed
-# if image_rotary_emb is not None:
-# from diffusers.models.embeddings import apply_rotary_emb
-
-# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-# if not attn.is_cross_attention:
-# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-# if SAGEATTN_IS_AVAILABLE:
-# hidden_states = sageattn(query, key, value, is_causal=False)
-# else:
-# hidden_states = F.scaled_dot_product_attention(
-# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-# )
-
-# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
-# # linear proj
-# hidden_states = attn.to_out[0](hidden_states)
-# # dropout
-# hidden_states = attn.to_out[1](hidden_states)
-
-# encoder_hidden_states, hidden_states = hidden_states.split(
-# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-# )
-# return hidden_states, encoder_hidden_states
-
#region Blocks
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
@@ -270,6 +198,7 @@ class CogVideoXBlock(nn.Module):
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
self.attn1 = Attention(
query_dim=dim,
@@ -308,11 +237,14 @@ class CogVideoXBlock(nn.Module):
fastercache_start_step=15,
fastercache_device="cuda:0"
) -> torch.Tensor:
+
text_seq_length = encoder_hidden_states.size(1)
+
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
+
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
@@ -347,19 +279,12 @@ class CogVideoXBlock(nn.Module):
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
-
-
+
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
- # has_nan = torch.isnan(hidden_states).any()
- # if has_nan:
- # raise ValueError(f"hs before norm2 has nan: {has_nan}")
- # has_inf = torch.isinf(hidden_states).any()
- # if has_inf:
- # raise ValueError(f"hs before norm2 has inf: {has_inf}")
-
# norm & modulate
+
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
@@ -604,45 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
- # def fuse_qkv_projections(self):
- # """
- # Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- # are fused. For cross-attention modules, key and value projection matrices are fused.
-
- #
-
- # This API is 🧪 experimental.
-
- #
- # """
- # self.original_attn_processors = None
-
- # for _, attn_processor in self.attn_processors.items():
- # if "Added" in str(attn_processor.__class__.__name__):
- # raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- # self.original_attn_processors = self.attn_processors
-
- # for module in self.modules():
- # if isinstance(module, Attention):
- # module.fuse_projections(fuse=True)
-
- # self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- # def unfuse_qkv_projections(self):
- # """Disables the fused QKV projection if enabled.
-
- #
-
- # This API is 🧪 experimental.
-
- #
-
- # """
- # if self.original_attn_processors is not None:
- # self.set_attn_processor(self.original_attn_processors)
def forward(
self,
@@ -679,8 +565,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
+
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
+
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
@@ -760,7 +648,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
output = torch.cat([output, recovered_uncond])
else:
for i, block in enumerate(self.transformer_blocks):
- #print("block", i)
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
diff --git a/model_loading.py b/model_loading.py
index 91c6ed3..00dfcee 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -261,7 +261,7 @@ class DownloadAndLoadCogVideoModel:
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
-
+
elif compile == "onediff":
from onediffx import compile_pipe
@@ -274,7 +274,7 @@ class DownloadAndLoadCogVideoModel:
ignores=["vae"],
fuse_qkv_projections=True if pab_config is None else False,
)
-
+
pipeline = {
"pipe": pipe,
"dtype": dtype,
@@ -453,6 +453,8 @@ class DownloadAndLoadCogVideoGGUFModel:
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
+
+
pipeline = {
"pipe": pipe,
"dtype": vae_dtype,
diff --git a/nodes.py b/nodes.py
index 4e4ce6f..fe4d367 100644
--- a/nodes.py
+++ b/nodes.py
@@ -861,6 +861,7 @@ class CogVideoSampler:
pipe.transformer.fastercache_counter = 0
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
+ autocastcondition = False ##todo
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
with autocast_context:
latents = pipeline["pipe"](