diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 05bed34..27f2d5f 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -128,7 +128,7 @@ class CogVideoXAttnProcessor2_0: if attention_mode == "sageattn" or attention_mode == "fused_sageattn": hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - elif attention_mode == "sdpa": + elif attention_mode == "sdpa" or attention_mode == "fused_sdpa": hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -751,4 +751,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) - \ No newline at end of file +