mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-10 05:14:22 +08:00
fix typos in STG attention
This commit is contained in:
parent
0ea77bc63f
commit
c9c9b50e40
@ -130,8 +130,8 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
|
hidden_states_uncond, hidden_states_cond, hidden_states_perturb = hidden_states.chunk(3, dim=0)
|
||||||
encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
|
encoder_hidden_states_uncond, encoder_hidden_states_cond, encoder_hidden_states_perturb = encoder_hidden_states.chunk(3, dim=0)
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states_uncond, hidden_states], dim=0)
|
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond], dim=0)
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states], dim=0)
|
encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
|
||||||
|
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
text_seq_length = encoder_hidden_states.size(1)
|
||||||
|
|
||||||
@ -197,27 +197,27 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
hidden_states *= feta_scores
|
hidden_states *= feta_scores
|
||||||
|
|
||||||
if self.stg_mode == "STG-A":
|
if self.stg_mode == "STG-A":
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
text_seq_length = encoder_hidden_states_perturb.size(1)
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
hidden_states_perturb = torch.cat([encoder_hidden_states_perturb, hidden_states_perturb], dim=1)
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
batch_size, sequence_length, _ = (
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
hidden_states_perturb.shape if encoder_hidden_states_perturb is None else encoder_hidden_states_perturb.shape
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
#if attention_mask is not None:
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
||||||
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
|
hidden_states_perturb = hidden_states_perturb.to(attn.to_q.weight.dtype)
|
||||||
|
|
||||||
if not "fused" in self.attention_mode:
|
if not "fused" in self.attention_mode:
|
||||||
query_perturb = attn.to_q(hidden_states)
|
query_perturb = attn.to_q(hidden_states_perturb)
|
||||||
key_perturb = attn.to_k(hidden_states)
|
key_perturb = attn.to_k(hidden_states_perturb)
|
||||||
value_perturb = attn.to_v(hidden_states)
|
value_perturb = attn.to_v(hidden_states_perturb)
|
||||||
else:
|
else:
|
||||||
qkv = attn.to_qkv(hidden_states)
|
qkv = attn.to_qkv(hidden_states_perturb)
|
||||||
split_size = qkv.shape[-1] // 3
|
split_size = qkv.shape[-1] // 3
|
||||||
query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
|
query_perturb, key_perturb, value_perturb = torch.split(qkv, split_size, dim=-1)
|
||||||
|
|
||||||
@ -254,19 +254,19 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.attention_mode != "comfy":
|
if self.attention_mode != "comfy":
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states_perturb = hidden_states_perturb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
|
||||||
# dropout
|
# dropout
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
hidden_states_perturb = attn.to_out[1](hidden_states_perturb)
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
encoder_hidden_states_perturb, hidden_states_perturb = hidden_states_perturb.split(
|
||||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)
|
hidden_states = torch.cat([hidden_states, hidden_states_perturb], dim=0)
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_perturb], dim=0)
|
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_perturb], dim=0)
|
||||||
|
|
||||||
return hidden_states, encoder_hidden_states
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user