diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index c751d13..9bbd87b 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -73,6 +73,7 @@ class CogVideoXAttnProcessor2_0:
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     @torch.compiler.disable()
+
     def __call__(
         self,
         attn: Attention,
@@ -115,14 +116,16 @@ class CogVideoXAttnProcessor2_0:
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        if SAGEATTN_IS_AVAILABLE:
-            hidden_states = sageattn(query, key, value, is_causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        #if SAGEATTN_IS_AVAILABLE:
+        #    hidden_states = sageattn(query, key, value, is_causal=False)
+        #else:
+        hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
+        if torch.isinf(hidden_states).any():
+            raise ValueError(f"hidden_states after dot product has inf")
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -136,81 +139,6 @@ class CogVideoXAttnProcessor2_0:
         )
         return hidden_states, encoder_hidden_states
 
-
-# class FusedCogVideoXAttnProcessor2_0:
-#     r"""
-#     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-#     query and key vectors, but does not include spatial normalization.
-#     """
-
-#     def __init__(self):
-#         if not hasattr(F, "scaled_dot_product_attention"):
-#             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-#     @torch.compiler.disable()
-#     def __call__(
-#         self,
-#         attn: Attention,
-#         hidden_states: torch.Tensor,
-#         encoder_hidden_states: torch.Tensor,
-#         attention_mask: Optional[torch.Tensor] = None,
-#         image_rotary_emb: Optional[torch.Tensor] = None,
-#     ) -> torch.Tensor:
-#         print("FusedCogVideoXAttnProcessor2_0")
-#         text_seq_length = encoder_hidden_states.size(1)
-
-#         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-#         batch_size, sequence_length, _ = (
-#             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-#         )
-
-#         if attention_mask is not None:
-#             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-#             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-#         qkv = attn.to_qkv(hidden_states)
-#         split_size = qkv.shape[-1] // 3
-#         query, key, value = torch.split(qkv, split_size, dim=-1)
-
-#         inner_dim = key.shape[-1]
-#         head_dim = inner_dim // attn.heads
-
-#         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-#         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-#         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-#         if attn.norm_q is not None:
-#             query = attn.norm_q(query)
-#         if attn.norm_k is not None:
-#             key = attn.norm_k(key)
-
-#         # Apply RoPE if needed
-#         if image_rotary_emb is not None:
-#             from diffusers.models.embeddings import apply_rotary_emb
-
-#             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-#             if not attn.is_cross_attention:
-#                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-#         if SAGEATTN_IS_AVAILABLE:
-#             hidden_states = sageattn(query, key, value, is_causal=False)
-#         else:
-#             hidden_states = F.scaled_dot_product_attention(
-#                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-#             )
-
-#         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
-#         # linear proj
-#         hidden_states = attn.to_out[0](hidden_states)
-#         # dropout
-#         hidden_states = attn.to_out[1](hidden_states)
-
-#         encoder_hidden_states, hidden_states = hidden_states.split(
-#             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-#         )
-#         return hidden_states, encoder_hidden_states
-
 #region Blocks
 @maybe_allow_in_graph
 class CogVideoXBlock(nn.Module):
@@ -270,6 +198,7 @@ class CogVideoXBlock(nn.Module):
 
         # 1. Self Attention
         self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
         self.attn1 = Attention(
             query_dim=dim,
@@ -308,11 +237,14 @@ class CogVideoXBlock(nn.Module):
         fastercache_start_step=15,
         fastercache_device="cuda:0"
     ) -> torch.Tensor:
+
         text_seq_length = encoder_hidden_states.size(1)
+
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
+
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
             H, W = video_flow_feature.shape[-2:]
@@ -347,19 +279,12 @@ class CogVideoXBlock(nn.Module):
         elif fastercache_counter > fastercache_start_step:
             self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
             self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
-        
-
+
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
 
-        # has_nan = torch.isnan(hidden_states).any()
-        # if has_nan:
-        #     raise ValueError(f"hs before norm2 has nan: {has_nan}")
-        # has_inf = torch.isinf(hidden_states).any()
-        # if has_inf:
-        #     raise ValueError(f"hs before norm2 has inf: {has_inf}")
-        # norm & modulate
+
         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
             hidden_states, encoder_hidden_states, temb
         )
@@ -604,45 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-    # def fuse_qkv_projections(self):
-    #     """
-    #     Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-    #     are fused. For cross-attention modules, key and value projection matrices are fused.
-
-    #     <Tip warning={true}>
-
-    #     This API is 🧪 experimental.
-
-    #     </Tip>
-    #     """
-    #     self.original_attn_processors = None
-
-    #     for _, attn_processor in self.attn_processors.items():
-    #         if "Added" in str(attn_processor.__class__.__name__):
-    #             raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-    #     self.original_attn_processors = self.attn_processors
-
-    #     for module in self.modules():
-    #         if isinstance(module, Attention):
-    #             module.fuse_projections(fuse=True)
-
-    #     self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    # def unfuse_qkv_projections(self):
-    #     """Disables the fused QKV projection if enabled.
-
-    #     <Tip warning={true}>
-
-    #     This API is 🧪 experimental.
-
-    #     </Tip>
-
-    #     """
-    #     if self.original_attn_processors is not None:
-    #         self.set_attn_processor(self.original_attn_processors)
 
     def forward(
         self,
@@ -679,8 +565,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
             hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
+
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
+
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
@@ -760,7 +648,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         else:
             for i, block in enumerate(self.transformer_blocks):
-                #print("block", i)
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
diff --git a/model_loading.py b/model_loading.py
index 91c6ed3..00dfcee 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -261,7 +261,7 @@ class DownloadAndLoadCogVideoModel:
                     if "CogVideoXBlock" in str(block):
                         pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
-                
+
             elif compile == "onediff":
                 from onediffx import compile_pipe
@@ -274,7 +274,7 @@ class DownloadAndLoadCogVideoModel:
                 ignores=["vae"],
                 fuse_qkv_projections=True if pab_config is None else False,
             )
-        
+
         pipeline = {
             "pipe": pipe,
             "dtype": dtype,
@@ -453,6 +453,8 @@ class DownloadAndLoadCogVideoGGUFModel:
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
+
+
         pipeline = {
             "pipe": pipe,
             "dtype": vae_dtype,
diff --git a/nodes.py b/nodes.py
index 4e4ce6f..fe4d367 100644
--- a/nodes.py
+++ b/nodes.py
@@ -861,6 +861,7 @@ class CogVideoSampler:
             pipe.transformer.fastercache_counter = 0
 
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
+        autocastcondition = False ##todo
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
         with autocast_context:
             latents = pipeline["pipe"](
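
Note: the main functional change above disables the sageattn path in favor of PyTorch's F.scaled_dot_product_attention and adds a guard against non-finite attention outputs. A minimal standalone sketch of that guard pattern, not part of the patch (the helper name and tensor shapes below are illustrative assumptions):

import torch
import torch.nn.functional as F

def sdpa_with_inf_check(query, key, value, attention_mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim), as in the attention processor above
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
    # Fail fast if the attention output overflowed (e.g. under low-precision inference)
    if torch.isinf(hidden_states).any():
        raise ValueError("hidden_states after dot product has inf")
    return hidden_states

# Example usage with dummy tensors:
# out = sdpa_with_inf_check(torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64))
# print(out.shape)  # torch.Size([1, 8, 16, 64])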