fix typos in STG attention
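The STG-A (spatio-temporal guidance) branch was reading from the re-assembled hidden_states / encoder_hidden_states tensors, and at the very end from non-existent *_org names, instead of the *_perturb slices it is meant to process. Below is a minimal sketch of the batch layout the fixed code assumes; shapes and variable names here are illustrative, not the repo's actual ones.

import torch

# Illustrative shapes only (batch = 1); the real tensors come from the transformer block.
hidden_states = torch.randn(3, 16, 8)          # video tokens, stacked [uncond, cond, perturb]
encoder_hidden_states = torch.randn(3, 4, 8)   # text tokens, same stacking

hs_uncond, hs_cond, hs_perturb = hidden_states.chunk(3, dim=0)
ehs_uncond, ehs_cond, ehs_perturb = encoder_hidden_states.chunk(3, dim=0)

# Ordinary attention path: only the uncond + cond samples (the pre-fix code
# concatenated the full three-way tensor back in here, silently growing the batch).
hidden_states = torch.cat([hs_uncond, hs_cond], dim=0)
encoder_hidden_states = torch.cat([ehs_uncond, ehs_cond], dim=0)

# Perturbed STG-A path: text and video tokens of the perturb slice only
# (the pre-fix code read from the already re-assembled hidden_states instead).
hidden_states_perturb = torch.cat([ehs_perturb, hs_perturb], dim=1)

# Attention runs on both paths; the perturbed output is then split back into
# text / video parts and stacked onto the ordinary output as a three-way batch.
ehs_perturb, hs_perturb = hidden_states_perturb.split([4, 16], dim=1)
hidden_states = torch.cat([hidden_states, hs_perturb], dim=0)                   # (3, 16, 8)
encoder_hidden_states = torch.cat([encoder_hidden_states, ehs_perturb], dim=0)  # (3, 4, 8)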

zhilemann 2024-12-22 03:32:22 +03:00 committed by GitHub
parent 0ea77bc63f
commit c9c9b50e40


@@ -130,8 +130,8 @@ class CogVideoXAttnProcessor2_0:
             hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
             encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
-            hidden_states = torch.cat([hidden_states_uncond, hidden_states], dim=0)
-            encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states], dim=0)
+            hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0)
+            encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)

         text_seq_length = encoder_hidden_states.size(1)
@@ -197,27 +197,27 @@ class CogVideoXAttnProcessor2_0:
             hidden_states *= feta_scores

         if self.stg_mode == "STG-A":
-            text_seq_length = encoder_hidden_states.size(1)
+            text_seq_length = encoder_hidden_states_perturb.size(1)

-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+            hidden_states_perturb = torch.cat([encoder_hidden_states_perturb, hidden_states_perturb], dim=1)

             batch_size, sequence_length, _ = (
-                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+                hidden_states_perturb.shape if encoder_hidden_states_perturb is None else encoder_hidden_states_perturb.shape
             )

-            if attention_mask is not None:
-                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            #if attention_mask is not None:
+            #    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            #    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

             if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
-                hidden_states = hidden_states.to(attn.to_q.weight.dtype)
+                hidden_states_perturb = hidden_states_perturb.to(attn.to_q.weight.dtype)

             if not "fused" in self.attention_mode:
-                query_perturb = attn.to_q(hidden_states)
-                key_perturb = attn.to_k(hidden_states)
-                value_perturb = attn.to_v(hidden_states)
+                query_perturb = attn.to_q(hidden_states_perturb)
+                key_perturb = attn.to_k(hidden_states_perturb)
+                value_perturb = attn.to_v(hidden_states_perturb)
             else:
-                qkv = attn.to_qkv(hidden_states)
+                qkv = attn.to_qkv(hidden_states_perturb)
                 split_size = qkv.shape[-1] // 3
                 query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
@@ -254,19 +254,19 @@ class CogVideoXAttnProcessor2_0:
             )
             if self.attention_mode != "comfy":
-                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

             # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
             # dropout
-            hidden_states = attn.to_out[1](hidden_states)
+            hidden_states_perturb = attn.to_out[1](hidden_states_perturb)

-            encoder_hidden_states, hidden_states = hidden_states.split(
+            encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split(
                 [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
             )

-            hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)
-            encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0)
+            hidden_states = torch.cat([hidden_states, hidden_states_perturb], dim=0)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_perturb], dim=0)

         return hidden_states, encoder_hidden_states