diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index a228219..50b0f25 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -125,7 +125,7 @@ class CogVideoXAttnProcessor2_0:
         if not attn.is_cross_attention:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

-        if attention_mode == "sageattn":
+        if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
            hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
        else:
            hidden_states = F.scaled_dot_product_attention(
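For context, here is a minimal sketch of the dispatch this hunk modifies, written as a standalone helper. The `dispatch_attention` name is hypothetical, and the `sageattn_func` signature is inferred from the call site in the diff; a membership test (`in`) is equivalent to the chained `or` in the patch:

```python
import torch
import torch.nn.functional as F

def dispatch_attention(query, key, value, attention_mask, attention_mode, sageattn_func):
    # Hypothetical helper mirroring the patched branch: both "sageattn" and
    # "fused_sageattn" route to the SageAttention kernel; any other mode falls
    # back to PyTorch's scaled_dot_product_attention.
    if attention_mode in ("sageattn", "fused_sageattn"):
        return sageattn_func(query, key, value, attn_mask=attention_mask,
                             dropout_p=0.0, is_causal=False)
    return F.scaled_dot_product_attention(query, key, value,
                                          attn_mask=attention_mask,
                                          dropout_p=0.0, is_causal=False)

# Example: a mode outside the SageAttention set falls back to SDPA,
# so no SageAttention callable is needed here.
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out = dispatch_attention(q, k, v, None, "sdpa", sageattn_func=None)
```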