mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
fix comfy attention output shape
This commit is contained in:
parent
822cb4ee1c
commit
b0eabeba24
@ -127,15 +127,15 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
|
|
||||||
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
||||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
elif attention_mode == "sdpa":
|
elif attention_mode == "sdpa":
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
elif attention_mode == "comfy":
|
elif attention_mode == "comfy":
|
||||||
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
# dropout
|
# dropout
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user