mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
fix typos in STG attention
This commit is contained in:
parent
0ea77bc63f
commit
c9c9b50e40
@ -130,8 +130,8 @@ class CogVideoXAttnProcessor2_0:
|
||||
hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
|
||||
encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_uncond, hidden_states], dim=0)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states], dim=0)
|
||||
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
@ -197,27 +197,27 @@ class CogVideoXAttnProcessor2_0:
|
||||
hidden_states *= feta_scores
|
||||
|
||||
if self.stg_mode == "STG-A":
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
text_seq_length = encoder_hidden_states_perturb.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states_perturb = torch.cat([encoder_hidden_states_perturb, hidden_states_perturb], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
hidden_states_perturb.shape if encoder_hidden_states_perturb is None else encoder_hidden_states_perturb.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
#if attention_mask is not None:
|
||||
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
|
||||
hidden_states_perturb = hidden_states_perturb.to(attn.to_q.weight.dtype)
|
||||
|
||||
if not "fused" in self.attention_mode:
|
||||
query_perturb = attn.to_q(hidden_states)
|
||||
key_perturb = attn.to_k(hidden_states)
|
||||
value_perturb = attn.to_v(hidden_states)
|
||||
query_perturb = attn.to_q(hidden_states_perturb)
|
||||
key_perturb = attn.to_k(hidden_states_perturb)
|
||||
value_perturb = attn.to_v(hidden_states_perturb)
|
||||
else:
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
qkv = attn.to_qkv(hidden_states_perturb)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
@ -254,19 +254,19 @@ class CogVideoXAttnProcessor2_0:
|
||||
)
|
||||
|
||||
if self.attention_mode != "comfy":
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
hidden_states_perturb = attn.to_out[1](hidden_states_perturb)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0)
|
||||
hidden_states = torch.cat([hidden_states, hidden_states_perturb], dim=0)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_perturb], dim=0)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user