Update custom_cogvideox_transformer_3d.py

kijai 2024-11-17 22:23:40 +02:00
parent e70da23ac2
commit 6f9e4ff647


@@ -125,7 +125,7 @@ class CogVideoXAttnProcessor2_0:
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-        if attention_mode == "sageattn":
+        if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
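
For context, a minimal sketch of the dispatch this change extends: the attention_mode string now routes both "sageattn" and "fused_sageattn" to the SageAttention callable, while every other mode falls back to PyTorch's scaled_dot_product_attention. The dispatch_attention name is hypothetical, and the assumption that sageattn_func accepts the attn_mask/dropout_p/is_causal keywords simply mirrors the call shown in the diff; this is not the wrapper's actual helper.

# Minimal sketch, not the repository's actual code.
# Assumptions: "dispatch_attention" is a hypothetical name; "sageattn_func"
# is any SageAttention-style callable resolved elsewhere that accepts the
# same keywords used in the diff above.
import torch.nn.functional as F

def dispatch_attention(query, key, value, attention_mask=None,
                       attention_mode="sdpa", sageattn_func=None):
    if attention_mode in ("sageattn", "fused_sageattn"):
        # Both SageAttention variants share one call path after this change.
        return sageattn_func(query, key, value, attn_mask=attention_mask,
                             dropout_p=0.0, is_causal=False)
    # Default path: PyTorch's fused scaled dot-product attention.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )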