fix comfy attention output shape

kijai 2024-11-19 20:18:13 +02:00
parent 822cb4ee1c
commit b0eabeba24

@@ -127,15 +127,15 @@ class CogVideoXAttnProcessor2_0:
         if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "sdpa":
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "comfy":
             hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
-            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
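
For context on the fix: ComfyUI's optimized_attention, when called with skip_reshape=True, takes (batch, heads, seq, head_dim) input and already returns the flattened (batch, seq, heads * head_dim) layout, so the extra transpose + reshape in the comfy branch was redundant. A minimal sketch of the two output layouts (dimensions chosen for illustration, not taken from the model):

    import torch
    import torch.nn.functional as F

    batch_size, heads, seq_len, head_dim = 2, 8, 16, 64
    query = key = value = torch.randn(batch_size, heads, seq_len, head_dim)

    # sdpa keeps the (batch, heads, seq, head_dim) layout, so its output
    # still needs transpose + reshape to reach (batch, seq, heads * head_dim):
    out = F.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    print(out.shape)  # torch.Size([2, 16, 512])

    # optimized_attention(..., skip_reshape=True) performs that transpose +
    # reshape internally, so its result is already (batch, seq, heads * head_dim);
    # running the line removed by this commit on it again would permute the
    # values instead of fixing the layout.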