From c9c9b50e40925bb8fcfa736099f3d65b9ddefb15 Mon Sep 17 00:00:00 2001 From: zhilemann <144045746+zhilemann@users.noreply.github.com> Date: Sun, 22 Dec 2024 03:32:22 +0300 Subject: [PATCH] fix typos in STG attention --- custom_cogvideox_transformer_3d.py | 38 +++++++++++++++--------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index d66e369..d56deb4 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -130,8 +130,8 @@ class CogVideoXAttnProcessor2_0: hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0) encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0) - hidden_states = torch.cat([hidden_states_uncond, hidden_states], dim=0) - encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states], dim=0) + hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0) + encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0) text_seq_length = encoder_hidden_states.size(1) @@ -197,27 +197,27 @@ class CogVideoXAttnProcessor2_0: hidden_states *= feta_scores if self.stg_mode == "STG-A": - text_seq_length = encoder_hidden_states.size(1) + text_seq_length = encoder_hidden_states_perturb.size(1) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states_perturb = torch.cat([encoder_hidden_states_perturb, hidden_states_perturb], dim=1) batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + hidden_states_perturb.shape if encoder_hidden_states_perturb is None else encoder_hidden_states_perturb.shape ) - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + #if attention_mask is not None: + # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16: - hidden_states = hidden_states.to(attn.to_q.weight.dtype) + hidden_states_perturb = hidden_states_perturb.to(attn.to_q.weight.dtype) if not "fused" in self.attention_mode: - query_perturb = attn.to_q(hidden_states) - key_perturb = attn.to_k(hidden_states) - value_perturb = attn.to_v(hidden_states) + query_perturb = attn.to_q(hidden_states_perturb) + key_perturb = attn.to_k(hidden_states_perturb) + value_perturb = attn.to_v(hidden_states_perturb) else: - qkv = attn.to_qkv(hidden_states) + qkv = attn.to_qkv(hidden_states_perturb) split_size = qkv.shape[-1] // 3 query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1) @@ -254,19 +254,19 @@ class CogVideoXAttnProcessor2_0: ) if self.attention_mode != "comfy": - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # linear proj - hidden_states = attn.to_out[0](hidden_states) + hidden_states_perturb = attn.to_out[0](hidden_states_perturb) # dropout - hidden_states = attn.to_out[1](hidden_states) + hidden_states_perturb = attn.to_out[1](hidden_states_perturb) - encoder_hidden_states, hidden_states = hidden_states.split( + encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split( [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 ) - hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0) - encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0) + hidden_states = torch.cat([hidden_states, hidden_states_perturb], dim=0) + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_perturb], dim=0) return hidden_states, encoder_hidden_states