From b0eabeba2427422719d181272efb25436e53615b Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 19 Nov 2024 20:18:13 +0200
Subject: [PATCH] fix comfy attention output shape

---
 custom_cogvideox_transformer_3d.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index e2aebba..05bed34 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -127,15 +127,15 @@ class CogVideoXAttnProcessor2_0:
 
         if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "sdpa":
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "comfy":
             hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
+
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
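
Note (not part of the patch): a minimal standalone sketch of the shape reasoning behind this fix, using only torch; the batch/head/sequence sizes below are illustrative. F.scaled_dot_product_attention returns (batch, heads, seq, head_dim) and so still needs the transpose + reshape, whereas the comfy path's optimized_attention(..., skip_reshape=True) already returns (batch, seq, heads * head_dim), as the patch above shows, so applying the reshape to it a second time would break the output shape.

# Sketch: why the reshape now lives inside the sdpa/sageattn branches only.
import torch
import torch.nn.functional as F

batch_size, heads, seq_len, head_dim = 2, 8, 16, 64  # illustrative sizes
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)

# sdpa path: output keeps the per-head layout (batch, heads, seq, head_dim)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(hidden_states.shape)  # torch.Size([2, 8, 16, 64])

# so the heads still have to be folded back into the channel dimension
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(hidden_states.shape)  # torch.Size([2, 16, 512]) == (batch, seq, heads * head_dim)

# The comfy branch's optimized_attention already returns (batch, seq, heads * head_dim),
# which is why the patch removes the shared reshape after the if/elif chain.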